Pembelajaran Terpadu untuk Klasifikasi Gambar

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Dalam tutorial ini, kita menggunakan contoh pelatihan MNIST klasik untuk memperkenalkan Federasi Learning (FL) lapisan API dari TFF, tff.learning - satu set antarmuka tingkat yang lebih tinggi yang dapat digunakan untuk melakukan jenis umum dari tugas-tugas belajar federasi, seperti pelatihan gabungan, terhadap model yang disediakan pengguna yang diimplementasikan di TensorFlow.

Tutorial ini, dan Federated Learning API, ditujukan terutama untuk pengguna yang ingin menyambungkan model TensorFlow mereka sendiri ke TFF, memperlakukan yang terakhir sebagian besar sebagai kotak hitam. Untuk pemahaman yang lebih mendalam tentang TFF dan bagaimana menerapkan algoritma pembelajaran Federasi Anda sendiri, lihat tutorial pada API FC Core - Kustom Federasi Algoritma Bagian 1 dan Bagian 2 .

Untuk lebih lanjut tentang tff.learning , lanjutkan dengan Federasi Belajar untuk Teks Generation , tutorial yang selain meliputi model berulang, juga menunjukkan memuat model Keras pra-dilatih serial untuk perbaikan dengan belajar Federasi dikombinasikan dengan evaluasi menggunakan Keras.

Sebelum kita mulai

Sebelum kita mulai, jalankan yang berikut ini untuk memastikan bahwa lingkungan Anda telah diatur dengan benar. Jika Anda tidak melihat salam, silakan merujuk ke Instalasi panduan untuk petunjuk.

# tensorflow_federated_nightly also bring in tf_nightly, which
# can causes a duplicate tensorboard install, leading to errors.
!pip uninstall --yes tensorboard tb-nightly

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio
!pip install --quiet --upgrade tb-nightly  # or tensorboard, but not both

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

Menyiapkan data masukan

Mari kita mulai dengan datanya. Pembelajaran gabungan memerlukan kumpulan data gabungan, yaitu kumpulan data dari banyak pengguna. Data federasi biasanya non iid , yang menimbulkan serangkaian tantangan yang unik.

Dalam rangka memfasilitasi eksperimen, kami diunggulkan repositori TFF dengan beberapa dataset, termasuk versi federasi dari MNIST yang berisi versi NIST dataset asli yang telah diproses ulang dengan menggunakan daun sehingga data yang bersemangat oleh penulis asli angka. Karena setiap penulis memiliki gaya yang unik, kumpulan data ini menunjukkan jenis perilaku non-iid yang diharapkan dari kumpulan data gabungan.

Inilah cara kita dapat memuatnya.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Data set dikembalikan oleh load_data() adalah contoh dari tff.simulation.ClientData , sebuah antarmuka yang memungkinkan Anda untuk menghitung set pengguna, untuk membangun sebuah tf.data.Dataset yang mewakili data pengguna tertentu, dan untuk query struktur elemen individu. Inilah cara Anda dapat menggunakan antarmuka ini untuk menjelajahi konten kumpulan data. Perlu diingat bahwa sementara antarmuka ini memungkinkan Anda untuk beralih pada id klien, ini hanya fitur dari data simulasi. Seperti yang akan Anda lihat segera, identitas klien tidak digunakan oleh kerangka pembelajaran federasi - satu-satunya tujuan mereka adalah untuk memungkinkan Anda memilih subset data untuk simulasi.

len(emnist_train.client_ids)
3383
emnist_train.element_type_structure
OrderedDict([('label', TensorSpec(shape=(), dtype=tf.int32, name=None)), ('pixels', TensorSpec(shape=(28, 28), dtype=tf.float32, name=None))])
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

example_element = next(iter(example_dataset))

example_element['label'].numpy()
1
from matplotlib import pyplot as plt
plt.imshow(example_element['pixels'].numpy(), cmap='gray', aspect='equal')
plt.grid(False)
_ = plt.show()

png

Menjelajahi heterogenitas dalam data gabungan

Data federasi biasanya non iid , pengguna biasanya memiliki distribusi data yang berbeda tergantung pada pola penggunaan. Beberapa klien mungkin memiliki lebih sedikit contoh pelatihan pada perangkat, menderita kekurangan data secara lokal, sementara beberapa klien akan memiliki lebih dari cukup contoh pelatihan. Mari kita jelajahi konsep heterogenitas data yang khas dari sistem federasi dengan data EMNIST yang kami miliki. Penting untuk dicatat bahwa analisis mendalam terhadap data klien ini hanya tersedia bagi kami karena ini adalah lingkungan simulasi di mana semua data tersedia bagi kami secara lokal. Dalam lingkungan federasi produksi nyata, Anda tidak akan dapat memeriksa data klien tunggal.

Pertama, mari ambil contoh data satu klien untuk merasakan contoh pada satu perangkat simulasi. Karena kumpulan data yang kami gunakan telah dikunci oleh penulis unik, data satu klien mewakili tulisan tangan satu orang untuk sampel angka 0 hingga 9, yang mensimulasikan "pola penggunaan" unik dari satu pengguna.

## Example MNIST digits for one client
figure = plt.figure(figsize=(20, 4))
j = 0

for example in example_dataset.take(40):
  plt.subplot(4, 10, j+1)
  plt.imshow(example['pixels'].numpy(), cmap='gray', aspect='equal')
  plt.axis('off')
  j += 1

png

Sekarang mari kita visualisasikan jumlah contoh pada setiap klien untuk setiap label digit MNIST. Dalam lingkungan federasi, jumlah contoh pada setiap klien dapat sedikit berbeda, tergantung pada perilaku pengguna.

# Number of examples per layer for a sample of clients
f = plt.figure(figsize=(12, 7))
f.suptitle('Label Counts for a Sample of Clients')
for i in range(6):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    # Append counts individually per label to make plots
    # more colorful instead of one color per plot.
    label = example['label'].numpy()
    plot_data[label].append(label)
  plt.subplot(2, 3, i+1)
  plt.title('Client {}'.format(i))
  for j in range(10):
    plt.hist(
        plot_data[j],
        density=False,
        bins=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

png

Sekarang mari kita visualisasikan gambar rata-rata per klien untuk setiap label MNIST. Kode ini akan menghasilkan rata-rata setiap nilai piksel untuk semua contoh pengguna untuk satu label. Kita akan melihat bahwa gambar rata-rata satu klien untuk satu digit akan terlihat berbeda dari gambar rata-rata klien lain untuk angka yang sama, karena gaya tulisan tangan yang unik dari setiap orang. Kita dapat merenungkan tentang bagaimana setiap putaran pelatihan lokal akan mendorong model ke arah yang berbeda pada setiap klien, karena kita belajar dari data unik pengguna itu sendiri di putaran lokal tersebut. Nanti dalam tutorial kita akan melihat bagaimana kita dapat mengambil setiap pembaruan model dari semua klien dan menggabungkannya ke dalam model global baru kita, yang telah dipelajari dari setiap data unik klien kita sendiri.

# Each client has different mean images, meaning each client will be nudging
# the model in their own directions locally.

for i in range(5):
  client_dataset = emnist_train.create_tf_dataset_for_client(
      emnist_train.client_ids[i])
  plot_data = collections.defaultdict(list)
  for example in client_dataset:
    plot_data[example['label'].numpy()].append(example['pixels'].numpy())
  f = plt.figure(i, figsize=(12, 5))
  f.suptitle("Client #{}'s Mean Image Per Label".format(i))
  for j in range(10):
    mean_img = np.mean(plot_data[j], 0)
    plt.subplot(2, 5, j+1)
    plt.imshow(mean_img.reshape((28, 28)))
    plt.axis('off')

png

png

png

png

png

Data pengguna bisa berisik dan diberi label yang tidak dapat diandalkan. Sebagai contoh, melihat data Klien #2 di atas, kita dapat melihat bahwa untuk label 2, mungkin ada beberapa contoh yang salah label sehingga menghasilkan gambar yang lebih berisik.

Memproses data input terlebih dahulu

Karena data sudah menjadi tf.data.Dataset , preprocessing dapat dicapai dengan menggunakan Dataset transformasi. Di sini, kita meratakan 28x28 gambar ke 784 array -element, mengocok contoh individu, mengatur mereka ke dalam batch, dan mengubah nama fitur dari pixels dan label untuk x dan y untuk digunakan dengan Keras. Kami juga melemparkan dalam repeat atas kumpulan data untuk menjalankan beberapa zaman.

NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
    return collections.OrderedDict(
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).batch(
      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

Mari kita verifikasi ini berhasil.

preprocessed_example_dataset = preprocess(example_dataset)

sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch
OrderedDict([('x', array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)), ('y', array([[2],
       [1],
       [5],
       [7],
       [1],
       [7],
       [7],
       [1],
       [4],
       [7],
       [4],
       [2],
       [2],
       [5],
       [4],
       [1],
       [1],
       [0],
       [0],
       [9]], dtype=int32))])

Kami memiliki hampir semua blok penyusun untuk membangun kumpulan data gabungan.

Salah satu cara untuk memberi makan data yang federasi untuk TFF dalam simulasi hanyalah sebagai daftar Python, dengan setiap elemen dari daftar memegang data pengguna individu, baik sebagai daftar atau sebagai tf.data.Dataset . Karena kita sudah memiliki antarmuka yang menyediakan yang terakhir, mari kita gunakan.

Berikut adalah fungsi pembantu sederhana yang akan menyusun daftar kumpulan data dari kumpulan pengguna yang diberikan sebagai masukan untuk putaran pelatihan atau evaluasi.

def make_federated_data(client_data, client_ids):
  return [
      preprocess(client_data.create_tf_dataset_for_client(x))
      for x in client_ids
  ]

Sekarang, bagaimana kita memilih klien?

Dalam skenario pelatihan gabungan yang khas, kita berhadapan dengan populasi perangkat pengguna yang berpotensi sangat besar, hanya sebagian kecil yang mungkin tersedia untuk pelatihan pada titik waktu tertentu. Ini adalah kasusnya, misalnya, ketika perangkat klien adalah ponsel yang berpartisipasi dalam pelatihan hanya ketika dicolokkan ke sumber daya, di luar jaringan terukur, dan jika tidak, menganggur.

Tentu saja, kita berada dalam lingkungan simulasi, dan semua data tersedia secara lokal. Biasanya, saat menjalankan simulasi, kami hanya akan mengambil sampel acak dari klien untuk dilibatkan dalam setiap putaran pelatihan, umumnya berbeda di setiap putaran.

Yang mengatakan, karena Anda dapat mengetahui dengan mempelajari kertas pada Averaging Federasi algoritma, mencapai konvergensi dalam sistem dengan subset acak sampel dari klien di setiap putaran dapat mengambil beberapa saat, dan akan tidak praktis harus menjalankan ratusan putaran di tutorial interaktif ini.

Apa yang akan kami lakukan adalah mengambil sampel kumpulan klien sekali, dan menggunakan kembali kumpulan yang sama di seluruh putaran untuk mempercepat konvergensi (sengaja terlalu pas dengan beberapa data pengguna ini). Kami membiarkannya sebagai latihan bagi pembaca untuk memodifikasi tutorial ini untuk mensimulasikan pengambilan sampel acak - ini cukup mudah dilakukan (setelah Anda melakukannya, perlu diingat bahwa mendapatkan model untuk menyatu mungkin memakan waktu cukup lama).

sample_clients = emnist_train.client_ids[0:NUM_CLIENTS]

federated_train_data = make_federated_data(emnist_train, sample_clients)

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))
Number of client datasets: 10
First dataset: <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>

Membuat model dengan Keras

Jika Anda menggunakan Keras, Anda mungkin sudah memiliki kode yang membangun model Keras. Berikut adalah contoh model sederhana yang akan mencukupi kebutuhan kita.

def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

Dalam rangka untuk menggunakan model dengan TFF, perlu dibungkus dalam sebuah instance dari tff.learning.Model antarmuka, yang memaparkan metode untuk cap model depan lulus, sifat metadata, dll, mirip dengan Keras, tetapi juga memperkenalkan tambahan elemen, seperti cara untuk mengontrol proses penghitungan metrik gabungan. Jangan khawatir tentang ini untuk saat ini; jika Anda memiliki model Keras seperti yang baru saja kita ditentukan di atas, Anda dapat memiliki TFF membungkusnya untuk Anda dengan menerapkan tff.learning.from_keras_model , melewati model dan batch data sampel sebagai argumen, seperti yang ditunjukkan di bawah ini.

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Melatih model pada data gabungan

Sekarang kita memiliki model dibungkus sebagai tff.learning.Model untuk digunakan dengan TFF, kita dapat membiarkan TFF membangun algoritma Averaging Federated dengan menerapkan fungsi helper tff.learning.build_federated_averaging_process , sebagai berikut.

Perlu diketahui bahwa argumen perlu konstruktor (seperti model_fn di atas), bukan sebuah contoh yang sudah dibangun, sehingga pembangunan model Anda dapat terjadi dalam konteks dikendalikan oleh TFF (jika Anda sedang ingin tahu tentang alasan ini, kami mendorong Anda untuk membaca tutorial tindak lanjut atas algoritma kustom ).

Satu catatan penting pada Averaging algoritma Federasi bawah, ada 2 pengoptimalan: optimasi _client dan optimizer _SERVER. The _client optimizer hanya digunakan untuk menghitung update Model lokal pada setiap klien. The _SERVER optimizer berlaku update rata-rata untuk model global pada server. Secara khusus, ini berarti bahwa pilihan pengoptimal dan kecepatan pembelajaran yang digunakan mungkin perlu berbeda dari yang Anda gunakan untuk melatih model pada kumpulan data iid standar. Kami merekomendasikan untuk memulai dengan SGD reguler, mungkin dengan tingkat pembelajaran yang lebih kecil dari biasanya. Tingkat pembelajaran yang kami gunakan belum disetel dengan hati-hati, jangan ragu untuk bereksperimen.

iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Apa yang baru saja terjadi? TFF telah membangun sepasang perhitungan federasi dan dikemas ke dalam sebuah tff.templates.IterativeProcess di mana perhitungan ini tersedia sebagai sepasang sifat initialize dan next .

Singkatnya, perhitungan federasi adalah program di bahasa internal TFF yang dapat mengekspresikan berbagai algoritma federasi (Anda dapat menemukan lebih lanjut tentang ini dalam adat algoritma tutorial). Dalam hal ini, dua perhitungan yang dihasilkan dan dikemas ke dalam iterative_process menerapkan Federasi Averaging .

Ini adalah tujuan TFF untuk mendefinisikan perhitungan dengan cara yang mereka dapat dieksekusi dalam pengaturan pembelajaran federasi nyata, tetapi saat ini hanya runtime simulasi eksekusi lokal yang diterapkan. Untuk menjalankan komputasi dalam simulator, Anda cukup memanggilnya seperti fungsi Python. Lingkungan interpretasi default ini tidak dirancang untuk kinerja tinggi, tetapi cukup untuk tutorial ini; kami berharap dapat menyediakan runtime simulasi berkinerja lebih tinggi untuk memfasilitasi penelitian skala besar di rilis mendatang.

Mari kita mulai dengan initialize perhitungan. Seperti halnya untuk semua komputasi gabungan, Anda dapat menganggapnya sebagai fungsi. Komputasi tidak memerlukan argumen, dan mengembalikan satu hasil - representasi status proses Rata-Rata Federasi di server. Meskipun kami tidak ingin menyelami detail TFF, mungkin bermanfaat untuk melihat seperti apa keadaan ini. Anda dapat memvisualisasikannya sebagai berikut.

str(iterative_process.initialize.type_signature)
'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<value_sum_process=<>,weight_sum_process=<>>,model_broadcast_state=<>>@SERVER)'

Sedangkan tipe tanda tangan di atas mungkin pada awalnya tampak samar sedikit, Anda dapat mengenali bahwa server negara terdiri dari model (model parameter awal untuk MNIST yang akan didistribusikan ke semua perangkat), dan optimizer_state (informasi tambahan dikelola oleh server, seperti jumlah putaran yang digunakan untuk jadwal hyperparameter, dll.).

Mari kita memanggil initialize perhitungan untuk membangun server negara.

state = iterative_process.initialize()

Kedua dari pasangan perhitungan federasi, next , merupakan satu putaran Federasi Averaging, yang terdiri dari mendorong negara Server (termasuk parameter model) kepada klien, pada perangkat pelatihan data lokal mereka, mengumpulkan dan update Model averaging , dan menghasilkan model baru yang diperbarui di server.

Secara konseptual, Anda bisa memikirkan next sebagai memiliki tanda tangan jenis fungsional yang terlihat sebagai berikut.

SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS

Secara khusus, salah satu harus berpikir tentang next() tidak sebagai fungsi yang berjalan di server, melainkan menjadi representasi fungsional deklaratif dari seluruh perhitungan desentralisasi - beberapa input yang disediakan oleh server ( SERVER_STATE ), tetapi masing-masing peserta perangkat menyumbangkan set data lokalnya sendiri.

Mari kita jalankan satu putaran pelatihan dan visualisasikan hasilnya. Kami dapat menggunakan data gabungan yang telah kami buat di atas untuk sampel pengguna.

state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12345679), ('loss', 3.1193738)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Mari kita jalankan beberapa putaran lagi. Seperti disebutkan sebelumnya, biasanya pada titik ini Anda akan memilih subset data simulasi Anda dari sampel pengguna baru yang dipilih secara acak untuk setiap putaran untuk mensimulasikan penerapan realistis di mana pengguna terus datang dan pergi, tetapi dalam buku catatan interaktif ini, untuk demi demonstrasi, kami hanya akan menggunakan kembali pengguna yang sama, sehingga sistem dapat menyatu dengan cepat.

NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13518518), ('loss', 2.9834728)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14382716), ('loss', 2.861665)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.17407407), ('loss', 2.7957022)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.19917695), ('loss', 2.6146567)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21975309), ('loss', 2.529761)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2409465), ('loss', 2.4053504)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2611111), ('loss', 2.315389)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.30823046), ('loss', 2.1240263)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.33312756), ('loss', 2.1164262)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Kehilangan pelatihan menurun setelah setiap putaran pelatihan gabungan, menunjukkan model konvergen. Ada beberapa keberatan penting dengan metrik pelatihan ini, bagaimanapun, lihat bagian Evaluasi nanti dalam tutorial ini.

Menampilkan metrik model di TensorBoard

Selanjutnya, mari visualisasikan metrik dari komputasi gabungan ini menggunakan Tensorboard.

Mari kita mulai dengan membuat direktori dan penulis ringkasan yang sesuai untuk menulis metrik.

logdir = "/tmp/logs/scalars/training/"
summary_writer = tf.summary.create_file_writer(logdir)
state = iterative_process.initialize()

Plot metrik skalar yang relevan dengan penulis ringkasan yang sama.

with summary_writer.as_default():
  for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, federated_train_data)
    for name, value in metrics['train'].items():
      tf.summary.scalar(name, value, step=round_num)

Mulai TensorBoard dengan direktori root log yang ditentukan di atas. Diperlukan beberapa detik untuk memuat data.

!ls {logdir}
%tensorboard --logdir {logdir} --port=0
events.out.tfevents.1629557449.ebe6e776479e64ea-4903924a278.borgtask.google.com.458912.1.v2
Launching TensorBoard...
Reusing TensorBoard on port 50681 (pid 292785), started 0:30:30 ago. (Use '!kill 292785' to kill it.)
<IPython.core.display.Javascript at 0x7fd6617e02d0>
# Uncomment and run this this cell to clean your directory of old output for
# future graphs from this directory. We don't run it by default so that if 
# you do a "Runtime > Run all" you don't lose your results.

# !rm -R /tmp/logs/scalars/*

Untuk melihat metrik evaluasi dengan cara yang sama, Anda dapat membuat folder eval terpisah, seperti "logs/scalars/eval", untuk menulis ke TensorBoard.

Menyesuaikan implementasi model

Keras adalah direkomendasikan tingkat tinggi Model API untuk TensorFlow , dan kami mendorong menggunakan model Keras (melalui tff.learning.from_keras_model ) di TFF bila memungkinkan.

Namun, tff.learning menyediakan antarmuka model tingkat yang lebih rendah, tff.learning.Model , yang mengekspos fungsi minimal yang diperlukan untuk menggunakan model pembelajaran federasi. Langsung mengimplementasikan interface ini (mungkin masih menggunakan blok bangunan seperti tf.keras.layers ) memungkinkan untuk kustomisasi maksimum tanpa memodifikasi internal dari algoritma pembelajaran Federasi.

Jadi mari kita lakukan semuanya lagi dari awal.

Mendefinisikan variabel model, forward pass, dan metrik

Langkah pertama adalah mengidentifikasi variabel TensorFlow yang akan kita kerjakan. Untuk membuat kode berikut lebih mudah dibaca, mari kita definisikan struktur data untuk mewakili seluruh rangkaian. Ini akan mencakup variabel seperti weights dan bias bahwa kita akan melatih, serta variabel yang akan menggelar berbagai statistik kumulatif dan counter kami akan memperbarui selama pelatihan, seperti loss_sum , accuracy_sum , dan num_examples .

MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Berikut adalah metode yang membuat variabel. Demi kesederhanaan, kami mewakili semua statistik sebagai tf.float32 , seperti yang akan menghilangkan kebutuhan untuk konversi tipe pada tahap berikutnya. Pembungkus initializers variabel sebagai lambdas adalah persyaratan yang diberlakukan oleh variabel sumber daya .

def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

Dengan variabel untuk parameter model dan statistik kumulatif, sekarang kita dapat mendefinisikan metode forward pass yang menghitung kerugian, memancarkan prediksi, dan memperbarui statistik kumulatif untuk satu kumpulan data input, sebagai berikut.

def predict_on_batch(variables, x):
  return tf.nn.softmax(tf.matmul(x, variables.weights) + variables.bias)

def mnist_forward_pass(variables, batch):
  y = predict_on_batch(variables, batch['x'])
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  flat_labels = tf.reshape(batch['y'], [-1])
  loss = -tf.reduce_mean(
      tf.reduce_sum(tf.one_hot(flat_labels, 10) * tf.math.log(y), axis=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, flat_labels), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  variables.num_examples.assign_add(num_examples)
  variables.loss_sum.assign_add(loss * num_examples)
  variables.accuracy_sum.assign_add(accuracy * num_examples)

  return loss, predictions

Selanjutnya, kami mendefinisikan fungsi yang mengembalikan sekumpulan metrik lokal, sekali lagi menggunakan TensorFlow. Ini adalah nilai (selain pembaruan model, yang ditangani secara otomatis) yang memenuhi syarat untuk digabungkan ke server dalam proses pembelajaran atau evaluasi gabungan.

Di sini, kita hanya mengembalikan rata-rata loss dan accuracy , serta num_examples , yang kita harus benar berat kontribusi dari pengguna yang berbeda ketika menghitung agregat Federasi.

def get_local_mnist_metrics(variables):
  return collections.OrderedDict(
      num_examples=variables.num_examples,
      loss=variables.loss_sum / variables.num_examples,
      accuracy=variables.accuracy_sum / variables.num_examples)

Akhirnya, kita perlu menentukan bagaimana agregat metrik lokal yang dipancarkan oleh masing-masing perangkat melalui get_local_mnist_metrics . Ini adalah satu-satunya bagian dari kode yang tidak ditulis dalam TensorFlow - itu adalah perhitungan federasi dinyatakan dalam TFF. Jika Anda ingin menggali lebih dalam, meluncur di atas kustom algoritma tutorial, tetapi dalam banyak aplikasi, Anda tidak akan benar-benar perlu; varian dari pola yang ditunjukkan di bawah ini sudah cukup. Berikut tampilannya:

@tff.federated_computation
def aggregate_mnist_metrics_across_clients(metrics):
  return collections.OrderedDict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_mean(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_mean(metrics.accuracy, metrics.num_examples))

Input metrics argumen berkorespondensi dengan OrderedDict dikembalikan oleh get_local_mnist_metrics di atas, tetapi kritis nilai-nilai tidak lagi tf.Tensors - mereka adalah "kotak" sebagai tff.Value s, untuk membuatnya jelas Anda tidak lagi dapat memanipulasi mereka menggunakan TensorFlow, tetapi hanya menggunakan operator federasi TFF seperti tff.federated_mean dan tff.federated_sum . Kamus agregat global yang dikembalikan mendefinisikan kumpulan metrik yang akan tersedia di server.

Membangun sebuah contoh dari tff.learning.Model

Dengan semua hal di atas, kami siap membuat representasi model untuk digunakan dengan TFF serupa dengan yang dibuat untuk Anda saat Anda membiarkan TFF menyerap model Keras.

from typing import Callable, List, OrderedDict

class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict(
        x=tf.TensorSpec([None, 784], tf.float32),
        y=tf.TensorSpec([None, 1], tf.int32))

  @tf.function
  def predict_on_batch(self, x, training=True):
    del training
    return predict_on_batch(self._variables, x)

  @tf.function
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    num_exmaples = tf.shape(batch['x'])[0]
    return tff.learning.BatchOutput(
        loss=loss, predictions=predictions, num_examples=num_exmaples)

  @tf.function
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_mnist_metrics_across_clients

  @tf.function
  def report_local_unfinalized_metrics(
      self) -> OrderedDict[str, List[tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to unfinalized values."""
    return collections.OrderedDict(
        num_examples=[self._variables.num_examples],
        loss=[self._variables.loss_sum, self._variables.num_examples],
        accuracy=[self._variables.accuracy_sum, self._variables.num_examples])

  def metric_finalizers(
      self) -> OrderedDict[str, Callable[[List[tf.Tensor]], tf.Tensor]]:
    """Creates an `OrderedDict` of metric names to finalizers."""
    return collections.OrderedDict(
        num_examples=tf.function(func=lambda x: x[0]),
        loss=tf.function(func=lambda x: x[0] / x[1]),
        accuracy=tf.function(func=lambda x: x[0] / x[1]))

Seperti yang Anda lihat, metode abstrak dan properti yang didefinisikan oleh tff.learning.Model berkorespondensi dengan potongan kode di bagian sebelumnya yang memperkenalkan variabel dan mendefinisikan kerugian dan statistik.

Berikut adalah beberapa poin yang patut disoroti:

  • Semua negara yang model Anda akan menggunakan harus ditangkap sebagai variabel TensorFlow, seperti TFF tidak menggunakan Python saat runtime (ingat kode Anda harus ditulis sedemikian rupa sehingga dapat digunakan untuk perangkat mobile; melihat kustom algoritma tutorial untuk lebih mendalam komentar tentang alasan).
  • Model Anda harus menjelaskan apa bentuk data yang menerima ( input_spec ), seperti pada umumnya, TFF adalah lingkungan kuat-mengetik dan ingin untuk menentukan tanda tangan jenis untuk semua komponen. Mendeklarasikan format input model Anda adalah bagian penting darinya.
  • Meskipun secara teknis tidak diperlukan, kami sarankan membungkus semua TensorFlow logika (depan lulus, metrik perhitungan, dll) sebagai tf.function s, karena hal ini membantu memastikan TensorFlow dapat serial, dan menghilangkan kebutuhan untuk dependensi kontrol eksplisit.

Di atas sudah cukup untuk evaluasi dan algoritma seperti Federated SGD. Namun, untuk Federated Averaging, kita perlu menentukan bagaimana model harus dilatih secara lokal pada setiap batch. Kami akan menentukan pengoptimal lokal saat membangun algoritma Rata-Rata Federasi.

Mensimulasikan pelatihan gabungan dengan model baru

Dengan semua hal di atas, sisa proses terlihat seperti apa yang telah kita lihat - cukup ganti konstruktor model dengan konstruktor kelas model baru kita, dan gunakan dua komputasi gabungan dalam proses iteratif yang Anda buat untuk siklus putaran pelatihan.

iterative_process = tff.learning.build_federated_averaging_process(
    MnistModel,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02))
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, federated_train_data)
print('round  1, metrics={}'.format(metrics))
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.0708053), ('accuracy', 0.12777779)])), ('stat', OrderedDict([('num_examples', 4860)]))])
for round_num in range(2, 11):
  state, metrics = iterative_process.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 3.011699), ('accuracy', 0.13024691)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.7408307), ('accuracy', 0.15576132)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.6761012), ('accuracy', 0.17921811)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.675567), ('accuracy', 0.1855967)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.5664043), ('accuracy', 0.20329218)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.4179392), ('accuracy', 0.24382716)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.3237286), ('accuracy', 0.26687244)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.1861682), ('accuracy', 0.28209877)])), ('stat', OrderedDict([('num_examples', 4860)]))])
round 10, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('num_examples', 4860.0), ('loss', 2.046388), ('accuracy', 0.32037038)])), ('stat', OrderedDict([('num_examples', 4860)]))])

Untuk melihat metrik ini dalam TensorBoard, lihat langkah-langkah yang tercantum di atas dalam "Menampilkan metrik model di TensorBoard".

Evaluasi

Semua eksperimen kami sejauh ini hanya menyajikan metrik pelatihan gabungan - metrik rata-rata untuk semua kumpulan data yang dilatih di semua klien dalam putaran tersebut. Ini memperkenalkan kekhawatiran normal tentang overfitting, terutama karena kami menggunakan set klien yang sama pada setiap putaran untuk kesederhanaan, tetapi ada gagasan tambahan tentang overfitting dalam metrik pelatihan khusus untuk algoritma Federated Averaging. Ini paling mudah untuk dilihat jika kita membayangkan setiap klien memiliki satu kumpulan data, dan kami melatih kumpulan itu untuk banyak iterasi (zaman). Dalam hal ini, model lokal akan dengan cepat cocok dengan satu batch itu, dan metrik akurasi lokal yang kami rata-rata akan mendekati 1,0. Dengan demikian, metrik pelatihan ini dapat dianggap sebagai tanda bahwa pelatihan sedang berlangsung, tetapi tidak lebih.

Untuk melakukan evaluasi data federasi, Anda dapat membuat perhitungan federasi lain yang dirancang untuk tujuan ini, menggunakan tff.learning.build_federated_evaluation fungsi, dan lewat dalam model konstruktor Anda sebagai argumen. Perhatikan bahwa tidak seperti dengan Federasi Averaging, di mana kita telah menggunakan MnistTrainableModel , itu sudah cukup untuk lulus MnistModel . Evaluasi tidak melakukan penurunan gradien, dan tidak perlu membuat pengoptimal.

Untuk eksperimen dan penelitian, ketika dataset tes terpusat tersedia, Federasi Belajar untuk Teks Generation menunjukkan pilihan evaluasi lain: mengambil bobot terlatih dari pembelajaran federasi, menerapkan mereka untuk model Keras standar, dan kemudian hanya memanggil tf.keras.models.Model.evaluate() pada dataset terpusat.

evaluation = tff.learning.build_federated_evaluation(MnistModel)

Anda dapat memeriksa tanda tangan tipe abstrak dari fungsi evaluasi sebagai berikut.

str(evaluation.type_signature)
'(<server_model_weights=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>@SERVER,federated_dataset={<x=float32[?,784],y=int32[?,1]>*}@CLIENTS> -> <eval=<num_examples=float32,loss=float32,accuracy=float32>,stat=<num_examples=int64>>@SERVER)'

Tidak perlu khawatir tentang rincian pada saat ini, hanya akan menyadari bahwa itu mengambil bentuk umum berikut, mirip dengan tff.templates.IterativeProcess.next tetapi dengan dua perbedaan penting. Pertama, kami tidak mengembalikan status server, karena evaluasi tidak mengubah model atau aspek status lainnya - Anda dapat menganggapnya sebagai stateless. Kedua, evaluasi hanya membutuhkan model, dan tidak memerlukan bagian lain dari status server yang mungkin terkait dengan pelatihan, seperti variabel pengoptimal.

SERVER_MODEL, FEDERATED_DATA -> TRAINING_METRICS

Mari kita lakukan evaluasi pada status terakhir yang kita capai selama pelatihan. Dalam rangka untuk mengekstrak model dilatih terbaru dari negara server, Anda cukup mengakses .model anggota, sebagai berikut.

train_metrics = evaluation(state.model, federated_train_data)

Inilah yang kami dapatkan. Perhatikan bahwa angka-angka tersebut terlihat sedikit lebih baik daripada yang dilaporkan oleh putaran terakhir pelatihan di atas. Secara konvensional, metrik pelatihan yang dilaporkan oleh proses pelatihan berulang umumnya mencerminkan kinerja model pada awal putaran pelatihan, sehingga metrik evaluasi akan selalu selangkah lebih maju.

str(train_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 4860.0), ('loss', 1.7510437), ('accuracy', 0.2788066)])), ('stat', OrderedDict([('num_examples', 4860)]))])"

Sekarang, mari kita menyusun sampel uji data gabungan dan menjalankan kembali evaluasi pada data uji. Data akan berasal dari sampel pengguna nyata yang sama, tetapi dari kumpulan data yang berbeda.

federated_test_data = make_federated_data(emnist_test, sample_clients)

len(federated_test_data), federated_test_data[0]
(10,
 <DatasetV1Adapter shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>)
test_metrics = evaluation(state.model, federated_test_data)
str(test_metrics)
"OrderedDict([('eval', OrderedDict([('num_examples', 580.0), ('loss', 1.8361608), ('accuracy', 0.2413793)])), ('stat', OrderedDict([('num_examples', 580)]))])"

Ini menyimpulkan tutorial. Kami mendorong Anda untuk bermain dengan parameter (misalnya, ukuran batch, jumlah pengguna, zaman, kecepatan pembelajaran, dll.), untuk memodifikasi kode di atas untuk mensimulasikan pelatihan pada sampel acak pengguna di setiap putaran, dan untuk menjelajahi tutorial lainnya kami telah mengembangkan.