Wydajna dla klienta sfederowana nauka dużego modelu za pośrednictwem federated_select i sparse agregation

Zobacz na TensorFlow.org Uruchom w Google Colab Wyświetl źródło na GitHub Pobierz notatnik

Ten tutorial pokazuje jak mogą być wykorzystane do TFF trenować bardzo duży model, w którym każde urządzenie klienckie tylko pobierać i aktualizować niewielką część modelu, używając tff.federated_select i agregację nieliczne. Chociaż ten poradnik jest dość powściągliwy The tff.federated_select poradnik i zwyczaj FL algorytmy poradnik zapewniają dobre wstępy do niektórych technik tutaj.

Konkretnie, w tym samouczku rozważymy regresję logistyczną dla klasyfikacji wieloetykietowej, przewidując, które „tagi” są skojarzone z ciągiem tekstowym na podstawie reprezentacji funkcji worka słów. Co ważne, koszty komunikacji i obliczeniowe po stronie klienta są kontrolowane przez ustalonej stałej ( MAX_TOKENS_SELECTED_PER_CLIENT ), a nie skala z ogólnej wielkości słownictwa, które mogą być bardzo duże w praktycznych ustawień.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import itertools
import numpy as np

from typing import Callable, List, Tuple

import tensorflow as tf
import tensorflow_federated as tff
tff.backends.native.set_local_python_execution_context()

Każdy klient będzie federated_select wiersze modelu ciężarów przez co większość tego wielu unikalnych żetonów. Ta górna aut wielkości modelu lokalnego klienta oraz kwotę serwer -> klient ( federated_select ) i klient -> serwer (federated_aggregate ) komunikacja wykonywana.

Ten samouczek powinien nadal działać poprawnie, nawet jeśli ustawisz wartość na 1 (upewniając się, że nie wszystkie tokeny z każdego klienta są wybrane) lub na dużą wartość, chociaż może wystąpić zbieżność modelu.

MAX_TOKENS_SELECTED_PER_CLIENT = 6

Definiujemy również kilka stałych dla różnych typów. Z tego colab, token jest identyfikatorem całkowitą dla danego słowa po parsowania zestawu danych.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = tf.int64
SELECT_KEY_DTYPE = tf.int32

# Type for counts of token occurences.
TOKEN_COUNT_DTYPE = tf.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE. 
# Our features are {0, 1} indicators, so we could potentially
# use tf.int8 as an optimization.
FEATURE_DTYPE = tf.int32

Konfiguracja problemu: zbiór danych i model

W tym samouczku konstruujemy mały zestaw danych zabawek, aby ułatwić eksperymentowanie. Jednak format zbioru danych jest kompatybilny z Federalne StackOverflow i wstępne przetwarzanie i modelu architektury są przyjmowane z StackOverflow predykcji tag problemu adaptacyjnego Federalne Optimization .

Parsowanie i wstępne przetwarzanie zbioru danych

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: List[str],
                    tag_vocab: List[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent that label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note:  We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids
def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn

Mały zbiór danych zabawek

Tworzymy mały zbiór danych zabawek z globalnym słownictwem 12 słów i 3 klientami. Ten drobny przykładem jest przydatna do badania przypadków krawędź (na przykład, dwa klientom poniżej MAX_TOKENS_SELECTED_PER_CLIENT = 6 różnych znaczników i jeden o więcej) i tworzenia kodu.

Jednak rzeczywistymi przypadkami użycia tego podejścia byłyby globalne słowniki dziesiątek milionów lub więcej, z prawdopodobnie tysiącami różnych tokenów pojawiających się na każdym kliencie. Ponieważ format danych jest taka sama, rozszerzenie do bardziej realistycznych problemach platform testowych, np tff.simulation.datasets.stackoverflow.load_data() zbioru danych, powinny być proste.

Najpierw definiujemy nasze słowniki słów i tagów.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS

# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

Teraz tworzymy 3 klientów z małymi lokalnymi zestawami danych. Jeśli korzystasz z tego samouczka w colab, przydatne może być użycie funkcji „lustrzanej komórki na karcie”, aby przypiąć tę komórkę i jej dane wyjściowe w celu zinterpretowania/sprawdzenia wyników poniższych funkcji.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)


def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d


# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tag
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Mathe the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

Zdefiniuj stałe dla surowych liczb funkcji wejściowych (tokeny/słowa) i etykiet (znaczniki postów). Nasze rzeczywiste przestrzenie wejścia / wyjścia są NUM_OOV_BUCKETS = 1 większy, ponieważ dodajemy OOV żeton / tag.

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)

WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

Twórz wsadowe wersje zestawów danych i pojedyncze partie, które będą przydatne podczas testowania kodu.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)

batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

Zdefiniuj model z rzadkimi danymi wejściowymi

Dla każdego tagu używamy prostego, niezależnego modelu regresji logistycznej.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated 
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model

Upewnijmy się, że to działa, najpierw dokonując prognoz:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

I kilka prostych scentralizowanych szkoleń:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

Bloki konstrukcyjne do obliczeń federacyjnych

Będziemy realizować wersję prostego w uśredniania Federalne algorytmu z kluczem różnicą, że każde urządzenie pobiera tylko odpowiedni podzbiór modelu, a jedynie przyczynia się dostawać do tego podzbioru.

Używamy M jako skrót dla MAX_TOKENS_SELECTED_PER_CLIENT . Na wysokim poziomie jedna runda szkolenia obejmuje następujące kroki:

  1. Każdy uczestniczący klient skanuje swój lokalny zestaw danych, analizując ciągi wejściowe i mapując je na prawidłowe tokeny (indeksy int). To wymaga dostępu do globalnej (duży) słowniku (może to potencjalnie można uniknąć stosując cecha hashowania techniki). Następnie rzadko liczymy, ile razy występuje każdy token. Jeśli U unikatowe znaczniki występuje w urządzeniu, to wybranie num_actual_tokens = min(U, M) najczęstszych tokeny pociągu.

  2. Klienci korzystać federated_select do pobierania modelu współczynniki dla num_actual_tokens wybranych znaków z serwera. Każdy model plaster jest tensor kształtu (TAG_VOCAB_SIZE, ) , więc łączne dane przesyłane do klienta jest co najwyżej o rozmiarze TAG_VOCAB_SIZE * M (patrz uwaga poniżej).

  3. Klienci skonstruować mapowania global_token -> local_token gdzie miejscowy żeton (int index) to wskaźnik globalnej tokena na liście wybranych żetonów.

  4. Klienci korzystać z „małą” wersję globalnego modelu, który ma tylko współczynniki co najwyżej M żetonów, z przedziału [0, num_actual_tokens) . global -> local odwzorowania jest używany do inicjacji gęste parametry tego modelu od wybranego modelu plasterki.

  5. Klienci trenować swój lokalny model używając SGD danych wstępnie przygotowane z global -> local odwzorowania.

  6. Klienci włączyć parametry lokalnego modelu do IndexedSlices aktualizacji przy użyciu local -> global mapowania do indeksu wierszy. Serwer agreguje te aktualizacje przy użyciu agregacji sum rzadkich.

  7. Serwer pobiera (gęsty) wynik powyższej agregacji, dzieli go przez liczbę uczestniczących klientów i stosuje wynikową średnią aktualizację do modelu globalnego.

W tej sekcji budujemy budulec dla tych kroków, które następnie zostaną połączone w końcowym federated_computation który przechwytuje pełne logika jednej serii treningowej.

Liczyć żetony klienckie i zdecydować, który model plastry do federated_select

Każde urządzenie musi zdecydować, które „wycinki” modelu są odpowiednie dla jego lokalnego zestawu danych treningowych. Dla naszego problemu robimy to przez (rzadko!) liczenie, ile przykładów zawiera każdy token w zbiorze danych uczących klienta.

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)
# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurences, since we still featurize
# multiple occurences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]

Będziemy wybierać parametry modelu odpowiadające MAX_TOKENS_SELECTED_PER_CLIENT najczęściej występujące znaki na urządzeniu. Jeśli mniej niż tylu żetonów występują na urządzeniu, my pad lista w celu umożliwienia korzystania z federated_select .

Zauważ, że inne strategie są prawdopodobnie lepsze, na przykład losowe wybieranie tokenów (być może na podstawie prawdopodobieństwa ich wystąpienia). Zapewniłoby to, że wszystkie wycinki modelu (dla których klient ma dane) mają pewną szansę na aktualizację.

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occuring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=tf.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens
# Simple test

# Case 1: actual_num_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: actual_num_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4

Mapuj globalne tokeny na lokalne tokeny

Powyższy wybór daje nam gęsty zestaw żetonów w przedziale [0, actual_num_tokens) , które będziemy używać do modelu na urządzeniu. Jednak zestaw danych czytamy posiada znaki ze znacznie większym zasięgu globalnym słownictwa [0, WORD_VOCAB_SIZE) .

Dlatego musimy zmapować globalne tokeny na odpowiadające im tokeny lokalne. Lokalne symboliczne identyfikatory są po prostu podane przez indeksy do selected_tokens tensor obliczonych w poprzednim kroku.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static 
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=sparse_tokens.dense_shape)
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))
# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch  = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0

Trenuj lokalny (pod)model na każdym kliencie

Uwaga federated_select powróci wybrane plasterki jako tf.data.Dataset w tej samej kolejności co klawiszy wyboru. Tak więc najpierw definiujemy funkcję użytkową, aby wziąć taki zestaw danych i przekonwertować go na pojedynczy gęsty tensor, który może być używany jako wagi modelu modelu klienta.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor
# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

Mamy teraz wszystkie komponenty potrzebne do zdefiniowania prostej lokalnej pętli szkoleniowej, która będzie działać na każdym kliencie.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0], 
                           size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta
# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceeding cell to avoid "tf.function-decorated function tried to 
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)

Zagregowane zindeksowane plastry

Używamy tff.federated_aggregate skonstruować stowarzyszonego rzadki sumę za IndexedSlices . Ta prosta implementacja ma ograniczenia, że dense_shape jest znany statycznie z góry. Należy również zauważyć, że suma ta jest tylko pół-rzadki w tym sensie, że klient -> serwer komunikacyjny jest rzadki, ale serwer utrzymuje zwartą reprezentację sumy w accumulate i merge i wysyła tę gęstą reprezentacji.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sumes IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the client's
    IndexedSclies.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tf_computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but 
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))


  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tf_computation(accumulate_slices),
      merge=tff.tf_computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tf_computation(lambda d: d))

Skonstruować minimalny federated_computation jako test

dense_shape = (6, 2)
indices_type = tff.TensorType(tf.int64, (None,))
values_type = tff.TensorType(tf.float32, (None, 2))
client_slice_type = tff.type_at_clients(
    (indices_type, values_type))

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

Wprowadzenie go wszyscy razem w federated_computation

Teraz używa TFF do związania ze sobą składników do tff.federated_computation .

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(tf.float32, shape=DENSE_MODEL_SHAPE)

Używamy podstawowej funkcji uczenia serwera opartej na uśrednianiu federacyjnym, stosując aktualizację z szybkością uczenia serwera 1,0. Ważne jest, abyśmy zastosowali aktualizację (delta) do modelu, zamiast po prostu uśredniać modele dostarczone przez klienta, ponieważ w przeciwnym razie, jeśli dany wycinek modelu nie został przeszkolony przez żadnego klienta w danej rundzie, jego współczynniki mogą zostać wyzerowane na zewnątrz.

@tff.tf_computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update

Musimy jeszcze kilka tff.tf_computation składniki:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tf_computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tf_computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tf_computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is amaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # W7e would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tf_computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)

Jesteśmy teraz gotowi, aby złożyć wszystkie elementy razem!

@tff.federated_computation(
    tff.type_at_server(model_type), tff.type_at_clients(client_data_type))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

Wytrenujmy modelkę!

Teraz, gdy mamy naszą funkcję treningową, wypróbujmy ją.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[ 
      tf.keras.metrics.Precision(name='precision'),
      tf.keras.metrics.AUC(name='auc'),
      tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
  ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in 
                          (zip(server_model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')
print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80