Học tập liên kết mô hình lớn hiệu quả với ứng dụng khách thông qua Federation_select và tập hợp thưa thớt

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Đây hướng dẫn cho thấy cách TFF có thể được sử dụng để đào tạo một mô hình rất lớn, nơi mỗi thiết bị khách hàng chỉ tải và cập nhật một phần nhỏ của các mô hình, sử dụng tff.federated_select và tập hợp thưa thớt. Trong khi hướng dẫn này là khá khép kín, các tff.federated_select hướng dẫntùy chỉnh FL thuật toán hướng dẫn cung cấp giới thiệu tốt để một số kỹ thuật sử dụng ở đây.

Cụ thể, trong hướng dẫn này, chúng tôi xem xét hồi quy logistic để phân loại nhiều nhãn, dự đoán "thẻ" nào được liên kết với chuỗi văn bản dựa trên đại diện tính năng nhiều từ. Quan trọng hơn, thông tin liên lạc và tính toán client-side chi phí được kiểm soát bởi một hằng số cố định ( MAX_TOKENS_SELECTED_PER_CLIENT ), và không mở rộng với quy mô vốn từ vựng chung, mà có thể là rất lớn trong môi trường thực tế.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff
tff.backends.native.set_local_python_execution_context()

Mỗi khách hàng sẽ federated_select các hàng của trọng số mô hình trong ít nhất này nhiều thẻ độc đáo. Đây phía trên tiếp giáp với kích thước của mô hình địa phương của khách hàng và số lượng máy chủ -> khách hàng ( federated_select ) và khách hàng -> Máy chủ (federated_aggregate ) truyền thông thực hiện.

Hướng dẫn này sẽ vẫn chạy chính xác ngay cả khi bạn đặt giá trị này nhỏ là 1 (đảm bảo không phải tất cả các mã thông báo từ mỗi khách hàng đều được chọn) hoặc thành một giá trị lớn, mặc dù sự hội tụ mô hình có thể bị ảnh hưởng.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

Chúng tôi cũng xác định một vài hằng số cho nhiều kiểu khác nhau. Đối với colab này, một mã thông báo là một định danh số nguyên cho một từ cụ thể sau khi phân tích các dữ liệu.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE. 
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

Thiết lập vấn đề: Tập dữ liệu và Mô hình

Chúng tôi xây dựng một bộ dữ liệu đồ chơi nhỏ để dễ dàng thử nghiệm trong hướng dẫn này. Tuy nhiên, định dạng của các tập dữ liệu tương thích với Federated StackOverflow , và tiền xử lýkiến trúc mô hình được thông qua từ vấn đề thẻ dự đoán StackOverflow của thích ứng Federated Tối ưu hóa .

Phân tích cú pháp và xử lý trước tập dữ liệu

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

Một tập dữ liệu đồ chơi nhỏ

Chúng tôi xây dựng một bộ dữ liệu đồ chơi nhỏ với vốn từ vựng toàn cầu gồm 12 từ và 3 khách hàng. Ví dụ nhỏ này rất hữu ích để thử nghiệm các trường hợp cạnh (ví dụ, chúng ta có hai khách hàng có ít hơn MAX_TOKENS_SELECTED_PER_CLIENT = 6 thẻ riêng biệt, và một với hơn) và phát triển các mã.

Tuy nhiên, các trường hợp sử dụng trong thế giới thực của cách tiếp cận này sẽ là các kho từ vựng toàn cầu từ 10 triệu trở lên, với khoảng 1000 mã thông báo riêng biệt xuất hiện trên mỗi khách hàng. Bởi vì định dạng của dữ liệu là như nhau, phần mở rộng đến các vấn đề nền tảng thử nghiệm thực tế hơn, ví dụ như tff.simulation.datasets.stackoverflow.load_data() dữ liệu, nên đơn giản.

Đầu tiên, chúng tôi xác định từ của chúng tôi và gắn thẻ các từ vựng.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

Bây giờ, chúng tôi tạo 3 máy khách với bộ dữ liệu cục bộ nhỏ. Nếu bạn đang chạy hướng dẫn này trong chuyên mục, có thể hữu ích khi sử dụng tính năng "ô nhân bản trong tab" để ghim ô này và đầu ra của nó nhằm diễn giải / kiểm tra kết quả đầu ra của các hàm được phát triển bên dưới.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Mathe the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

Xác định các hằng số cho số lượng nguyên của các tính năng đầu vào (mã thông báo / từ) và nhãn (thẻ bài đăng). Không gian đầu vào / đầu ra thực tế của chúng tôi là NUM_OOV_BUCKETS = 1 lớn hơn bởi vì chúng ta thêm một OOV thẻ / thẻ.

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

Tạo các phiên bản theo lô của tập dữ liệu và các lô riêng lẻ, điều này sẽ hữu ích trong việc kiểm tra mã khi chúng ta tiếp tục.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

Xác định một mô hình có đầu vào thưa thớt

Chúng tôi sử dụng mô hình hồi quy logistic độc lập đơn giản cho mỗi thẻ.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated 
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

Hãy đảm bảo rằng nó hoạt động, trước tiên bằng cách đưa ra các dự đoán:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

Và một số đào tạo tập trung đơn giản:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

Các khối xây dựng cho phép tính liên hợp

Chúng tôi sẽ thực hiện một phiên bản đơn giản của trung bình Federated thuật toán với sự khác biệt quan trọng mà mỗi thiết bị chỉ tải về một tập hợp con có liên quan của các mô hình, và chỉ góp phần cập nhật để tập hợp con đó.

Chúng tôi sử dụng M là viết tắt cho MAX_TOKENS_SELECTED_PER_CLIENT . Ở cấp độ cao, một vòng đào tạo bao gồm các bước sau:

  1. Mỗi ứng dụng khách tham gia quét qua tập dữ liệu cục bộ của nó, phân tích cú pháp các chuỗi đầu vào và ánh xạ chúng tới các mã thông báo chính xác (int indexes). Điều này đòi hỏi quyền truy cập vào toàn cầu từ điển (lớn) (điều này có khả năng có thể tránh sử dụng tính năng băm kỹ thuật). Sau đó, chúng tôi đếm số lần mỗi mã thông báo xuất hiện một cách thưa thớt. Nếu U thẻ độc đáo xảy ra trên thiết bị, chúng tôi chọn num_actual_tokens = min(U, M) hầu hết các thẻ thường xuyên để đào tạo.

  2. Các khách hàng sử dụng federated_select để lấy các hệ số mô hình cho num_actual_tokens chọn thẻ từ máy chủ. Mỗi mô hình lát là một tensor hình dạng (TAG_VOCAB_SIZE, ) , vì vậy tổng số dữ liệu truyền tới khách hàng là tại hầu hết các kích thước TAG_VOCAB_SIZE * M (xem lưu ý dưới đây).

  3. Các khách hàng xây dựng một bản đồ global_token -> local_token nơi token địa phương (int index) là chỉ số của dấu hiệu toàn cầu trong danh sách các thẻ được chọn.

  4. Các khách hàng sử dụng một phiên bản "nhỏ" của mô hình toàn cầu mà chỉ có hệ số trong ít nhất M tokens, từ dãy [0, num_actual_tokens) . Các global -> local lập bản đồ được sử dụng để khởi tạo các thông số dày đặc của mô hình này từ mô hình lát chọn.

  5. Khách hàng đào tạo mô hình địa phương của họ sử dụng SGD trên dữ liệu xử lý trước với global -> local lập bản đồ.

  6. Khách hàng lần lượt các thông số của mô hình địa phương của họ vào IndexedSlices cập nhật bằng cách sử dụng local -> global bản đồ để chỉ mục các hàng. Máy chủ tổng hợp các bản cập nhật này bằng cách sử dụng tập hợp tổng thưa thớt.

  7. Máy chủ lấy kết quả (dày đặc) của tổng hợp ở trên, chia nó cho số lượng khách hàng tham gia và áp dụng cập nhật trung bình kết quả cho mô hình toàn cầu.

Trong phần này chúng ta xây dựng các khối xây dựng cho các bước này, sau đó sẽ được kết hợp trong một trận chung kết federated_computation rằng ảnh chụp toàn bộ logic của một vòng đào tạo.

Đếm thẻ khách hàng và quyết định mô hình lát để federated_select

Mỗi thiết bị cần quyết định "lát cắt" nào của mô hình có liên quan đến tập dữ liệu đào tạo cục bộ của nó. Đối với vấn đề của chúng tôi, chúng tôi thực hiện điều này bằng cách (thưa thớt!) Đếm xem có bao nhiêu ví dụ chứa mỗi mã thông báo trong tập dữ liệu đào tạo khách hàng.

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

Chúng tôi sẽ lựa chọn các thông số mô hình tương ứng với MAX_TOKENS_SELECTED_PER_CLIENT thường xuyên nhất xảy ra tokens trên thiết bị. Nếu ít hơn nhiều thẻ này xảy ra trên thiết bị, chúng tôi pad danh sách để cho phép việc sử dụng các federated_select .

Lưu ý rằng các chiến lược khác có thể tốt hơn, ví dụ: chọn ngẫu nhiên các mã thông báo (có thể dựa trên xác suất xuất hiện của chúng). Điều này sẽ đảm bảo rằng tất cả các phần của mô hình (mà khách hàng có dữ liệu) đều có cơ hội được cập nhật.

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occuring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

Ánh xạ mã thông báo toàn cầu với mã thông báo địa phương

Việc lựa chọn ở trên cho chúng ta một tập trù mật của thẻ trong phạm vi [0, actual_num_tokens) mà chúng tôi sẽ sử dụng cho các mô hình trên thiết bị. Tuy nhiên, các số liệu chúng ta đọc có thẻ từ lớn hơn nhiều phạm vi toàn cầu từ vựng [0, WORD_VOCAB_SIZE) .

Do đó, chúng ta cần ánh xạ các mã thông báo toàn cầu với các mã thông báo địa phương tương ứng của chúng. Id thẻ địa phương chỉ đơn giản được đưa ra bởi các chỉ số vào selected_tokens tensor tính trong bước trước.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static 
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=sparse_tokens.dense_shape)
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch  = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

Đào tạo mô hình cục bộ (phụ) trên mỗi khách hàng

Lưu ý federated_select sẽ trả lại lát chọn là một tf.data.Dataset theo thứ tự giống như các phím chọn. Vì vậy, trước tiên chúng ta xác định một hàm tiện ích để lấy một Tập dữ liệu như vậy và chuyển đổi nó thành một tensor dày đặc duy nhất có thể được sử dụng làm trọng số mô hình của mô hình khách hàng.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

Bây giờ chúng ta có tất cả các thành phần chúng ta cần để xác định một vòng lặp huấn luyện cục bộ đơn giản sẽ chạy trên mỗi máy khách.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], 
                           size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to 
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

Aggregate IndexedSlices

Chúng tôi sử dụng tff.federated_aggregate để xây dựng một tổng thưa thớt liên cho IndexedSlices . Thực hiện đơn giản này có ràng buộc là các dense_shape được biết đến tĩnh trước. Cũng lưu ý rằng số tiền này là chỉ bán thưa thớt, theo nghĩa là khách hàng -> giao tiếp máy chủ là thưa thớt, nhưng máy chủ duy trì một đại diện dày đặc của tổng trong accumulatemerge , và kết quả đầu ra đại diện dày đặc này.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but 
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))


  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

Xây dựng một tối thiểu federated_computation như một thử nghiệm

dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

Đưa nó tất cả cùng nhau trong một federated_computation

Bây giờ chúng ta sử dụng TFF để ràng buộc cùng các thành phần vào một tff.federated_computation .

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

Chúng tôi sử dụng chức năng đào tạo máy chủ cơ bản dựa trên Tính trung bình liên kết, áp dụng bản cập nhật với tốc độ học tập của máy chủ là 1,0. Điều quan trọng là chúng tôi phải áp dụng bản cập nhật (delta) cho mô hình, thay vì chỉ tính trung bình các mô hình do khách hàng cung cấp, vì nếu không, nếu một phần nhất định của mô hình không được khách hàng nào đào tạo trong một vòng nhất định thì hệ số của nó có thể bằng không ngoài.

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

Chúng tôi cần thêm một vài tff.tf_computation thành phần:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is amaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # W7e would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

Bây giờ chúng tôi đã sẵn sàng để ghép tất cả các mảnh lại với nhau!

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

Hãy đào tạo một người mẫu!

Bây giờ chúng ta đã có chức năng đào tạo của mình, hãy thử nó.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[ 
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
  ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in 
                          (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80