การเรียนรู้แบบสหพันธรัฐสำหรับการสร้างข้อความ

ดูบน TensorFlow.org ทำงานใน Google Colab ดูแหล่งที่มาบน GitHub ดาวน์โหลดโน๊ตบุ๊ค

กวดวิชานี้จะสร้างขึ้นบนแนวคิดในการ เรียนรู้สหพันธ์สำหรับภาพการจำแนกประเภท กวดวิชาและแสดงให้เห็นถึงหลายวิธีที่มีประโยชน์อื่น ๆ สำหรับการเรียนรู้แบบ federated

โดยเฉพาะอย่างยิ่ง เราโหลดโมเดล Keras ที่ได้รับการฝึกอบรมมาก่อนหน้านี้ และปรับแต่งโดยใช้การฝึกแบบรวมศูนย์บนชุดข้อมูลแบบกระจายศูนย์ (จำลอง) นี่เป็นสิ่งสำคัญในทางปฏิบัติด้วยเหตุผลหลายประการ ความสามารถในการใช้แบบจำลองต่อเนื่องทำให้ง่ายต่อการผสมผสานการเรียนรู้แบบสหพันธรัฐกับแนวทาง ML อื่นๆ ต่อไปนี้จะช่วยให้การใช้งานที่หลากหลายเพิ่มมากขึ้นของรุ่นก่อนได้รับการฝึกฝน --- ตัวอย่างเช่นรูปแบบการฝึกอบรมภาษาจากรอยขีดข่วนจะไม่ค่อยมีความจำเป็นเช่นหลายรุ่นก่อนการฝึกอบรมอยู่ในขณะนี้สามารถใช้ได้อย่างกว้างขวาง (ดูเช่น TF Hub ) คุณควรเริ่มต้นจากโมเดลที่ได้รับการฝึกอบรมล่วงหน้า และปรับแต่งโดยใช้ Federated Learning แทน โดยปรับให้เข้ากับลักษณะเฉพาะของข้อมูลแบบกระจายศูนย์สำหรับแอปพลิเคชันเฉพาะ

สำหรับบทช่วยสอนนี้ เราเริ่มต้นด้วย RNN ที่สร้างอักขระ ASCII และปรับแต่งผ่านการเรียนรู้แบบรวมศูนย์ นอกจากนี้เรายังแสดงวิธีการป้อนตุ้มน้ำหนักสุดท้ายกลับไปยังโมเดล Keras ดั้งเดิม ซึ่งช่วยให้ประเมินและสร้างข้อความได้ง่ายโดยใช้เครื่องมือมาตรฐาน

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import collections
import functools
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()
b'Hello, World!'

โหลดโมเดลที่ฝึกไว้ล่วงหน้า

เราโหลดรูปแบบที่ผ่านการฝึกอบรมก่อนต่อไป TensorFlow กวดวิชา รุ่นข้อความโดยใช้ RNN กับการดำเนินความกระตือรือร้น แต่มากกว่าการฝึกอบรม ที่สมบูรณ์ผลงานของเชคสเปีย เราก่อนการฝึกอบรมรุ่นที่ข้อความจากชาร์ลส์ดิคเก้นส เรื่องของสองนคร และ คริสต์มาส

นอกเหนือจากการขยายคำศัพท์ เราไม่ได้แก้ไขบทช่วยสอนดั้งเดิม ดังนั้นรูปแบบเริ่มต้นนี้จึงไม่ล้ำสมัย แต่ให้การคาดคะเนที่สมเหตุสมผลและเพียงพอสำหรับจุดประสงค์ในการสอนของเรา รุ่นสุดท้ายถูกบันทึกไว้ด้วย tf.keras.models.save_model(include_optimizer=False)

เราจะใช้การเรียนรู้แบบรวมศูนย์เพื่อปรับแต่งโมเดลนี้สำหรับเช็คสเปียร์ในบทช่วยสอนนี้ โดยใช้ข้อมูลเวอร์ชันรวมที่ได้รับจาก TFF

สร้างตารางค้นหาคำศัพท์

# A fixed vocabularly of ASCII chars that occur in the works of Shakespeare and Dickens:
vocab = list('dhlptx@DHLPTX $(,048cgkoswCGKOSW[_#\'/37;?bfjnrvzBFJNRVZ"&*.26:\naeimquyAEIMQUY]!%)-159\r')

# Creating a mapping from unique characters to indices
char2idx = {u:i for i, u in enumerate(vocab)}
idx2char = np.array(vocab)

โหลดโมเดลที่ฝึกไว้ล่วงหน้าแล้วสร้างข้อความ

def load_model(batch_size):
  urls = {
      1: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel',
      8: 'https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel'}
  assert batch_size in urls, 'batch_size must be in ' + str(urls.keys())
  url = urls[batch_size]
  local_file = tf.keras.utils.get_file(os.path.basename(url), origin=url)  
  return tf.keras.models.load_model(local_file, compile=False)
def generate_text(model, start_string):
  # From https://www.tensorflow.org/tutorials/sequences/text_generation
  num_generate = 200
  input_eval = [char2idx[s] for s in start_string]
  input_eval = tf.expand_dims(input_eval, 0)
  text_generated = []
  temperature = 1.0

  model.reset_states()
  for i in range(num_generate):
    predictions = model(input_eval)
    predictions = tf.squeeze(predictions, 0)
    predictions = predictions / temperature
    predicted_id = tf.random.categorical(
        predictions, num_samples=1)[-1, 0].numpy()
    input_eval = tf.expand_dims([predicted_id], 0)
    text_generated.append(idx2char[predicted_id])

  return (start_string + ''.join(text_generated))
# Text generation requires a batch_size=1 model.
keras_model_batch1 = load_model(batch_size=1)
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch1.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
What of TensorFlow Federated, you ask? Sall
yesterday. Received the Bailey."

"Mr. Lorry, grimmering himself, or low varked thends the winter, and the eyes of Monsieur
Defarge. "Let his mind, hon in his
life and message; four declare

โหลดและประมวลผล Federated Shakespeare Data ล่วงหน้า

tff.simulation.datasets แพคเกจให้ความหลากหลายของชุดข้อมูลที่มีการแบ่งออกเป็น "ลูกค้า" ที่ลูกค้าแต่ละสอดคล้องกับชุดข้อมูลบนอุปกรณ์เฉพาะที่อาจจะมีส่วนร่วมในการเรียนรู้แบบ federated

ชุดข้อมูลเหล่านี้ให้การกระจายข้อมูลแบบ non-IID ที่สมจริง ซึ่งจำลองความท้าทายของการฝึกอบรมเกี่ยวกับข้อมูลที่กระจายอำนาจจริงในการจำลอง บางส่วนของการประมวลผลก่อนของข้อมูลนี้ถูกทำโดยใช้เครื่องมือจาก โครงการใบ ( GitHub )

train_data, test_data = tff.simulation.datasets.shakespeare.load_data()

ชุดข้อมูลที่มีให้โดย shakespeare.load_data() ประกอบด้วยลำดับของสตริง Tensors หนึ่งสำหรับแต่ละบรรทัดที่พูดโดยตัวละครโดยเฉพาะอย่างยิ่งในการเล่นเช็คสเปียร์ คีย์ลูกค้าประกอบด้วยชื่อของการเล่นร่วมกับชื่อของตัวละครเพื่อให้ตัวอย่างเช่น MUCH_ADO_ABOUT_NOTHING_OTHELLO สอดคล้องกับเส้นสำหรับตัว Othello ในการเล่น Much Ado About Nothing โปรดทราบว่าในไคลเอนต์สถานการณ์การเรียนรู้แบบรวมศูนย์จริงจะไม่ถูกระบุหรือติดตามโดยรหัส แต่สำหรับการจำลอง จะมีประโยชน์ในการทำงานกับชุดข้อมูลที่มีคีย์

ตัวอย่างเช่น เราสามารถดูข้อมูลบางส่วนจาก King Lear:

# Here the play is "The Tragedy of King Lear" and the character is "King".
raw_example_dataset = train_data.create_tf_dataset_for_client(
    'THE_TRAGEDY_OF_KING_LEAR_KING')
# To allow for future extensions, each entry x
# is an OrderedDict with a single key 'snippets' which contains the text.
for x in raw_example_dataset.take(2):
  print(x['snippets'])
tf.Tensor(b'', shape=(), dtype=string)
tf.Tensor(b'What?', shape=(), dtype=string)

ขณะนี้เราใช้ tf.data.Dataset แปลงข้อมูลเพื่อเตรียมความพร้อมสำหรับการฝึกอบรมนี้ถ่าน RNN โหลดดังกล่าวข้างต้น

# Input pre-processing parameters
SEQ_LENGTH = 100
BATCH_SIZE = 8
BUFFER_SIZE = 100  # For dataset shuffling
# Construct a lookup table to map string chars to indexes,
# using the vocab loaded above:
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=vocab, values=tf.constant(list(range(len(vocab))),
                                       dtype=tf.int64)),
    default_value=0)


def to_ids(x):
  s = tf.reshape(x['snippets'], shape=[1])
  chars = tf.strings.bytes_split(s).values
  ids = table.lookup(chars)
  return ids


def split_input_target(chunk):
  input_text = tf.map_fn(lambda x: x[:-1], chunk)
  target_text = tf.map_fn(lambda x: x[1:], chunk)
  return (input_text, target_text)


def preprocess(dataset):
  return (
      # Map ASCII chars to int64 indexes using the vocab
      dataset.map(to_ids)
      # Split into individual chars
      .unbatch()
      # Form example sequences of SEQ_LENGTH +1
      .batch(SEQ_LENGTH + 1, drop_remainder=True)
      # Shuffle and form minibatches
      .shuffle(BUFFER_SIZE).batch(BATCH_SIZE, drop_remainder=True)
      # And finally split into (input, target) tuples,
      # each of length SEQ_LENGTH.
      .map(split_input_target))

โปรดทราบว่าในการก่อตัวของลำดับเดิมและในรูปแบบของแบตช์ดังกล่าวข้างต้นเราจะใช้ drop_remainder=True สำหรับความเรียบง่าย ที่นี้หมายถึงว่าตัวละครใด ๆ (ลูกค้า) ที่ไม่ได้มีอย่างน้อย (SEQ_LENGTH + 1) * BATCH_SIZE ตัวอักษรของข้อความจะมีชุดข้อมูลที่ว่างเปล่า วิธีทั่วไปในการแก้ไขปัญหานี้คือการเพิ่มโทเค็นพิเศษลงในแบตช์ แล้วปิดบังการสูญเสียเพื่อไม่ให้พิจารณาโทเค็นการแพ็ด

นี้จะมีความซับซ้อนเช่นค่อนข้างดังนั้นสำหรับการกวดวิชานี้เราจะใช้สำหรับกระบวนการเต็มรูปแบบเช่นเดียวกับใน การกวดวิชามาตรฐาน อย่างไรก็ตาม ในการตั้งค่าแบบรวมศูนย์ ปัญหานี้มีความสำคัญมากกว่า เนื่องจากผู้ใช้จำนวนมากอาจมีชุดข้อมูลขนาดเล็ก

ตอนนี้เราสามารถ preprocess เรา raw_example_dataset และเช็คประเภท:

example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)
(TensorSpec(shape=(8, 100), dtype=tf.int64, name=None), TensorSpec(shape=(8, 100), dtype=tf.int64, name=None))

รวบรวมแบบจำลองและทดสอบกับข้อมูลที่ประมวลผลล่วงหน้า

เราโหลดไม่ได้คอมรุ่น keras แต่เพื่อที่จะรัน keras_model.evaluate เราต้องรวบรวมกับการสูญเสียและตัวชี้วัด นอกจากนี้ เราจะคอมไพล์ในตัวเพิ่มประสิทธิภาพ ซึ่งจะใช้เป็นตัวเพิ่มประสิทธิภาพบนอุปกรณ์ในการเรียนรู้แบบรวมศูนย์

บทช่วยสอนดั้งเดิมไม่มีความแม่นยำระดับอักขระ (เศษส่วนของการคาดคะเนที่มีความเป็นไปได้สูงสุดในอักขระถัดไปที่ถูกต้อง) นี่เป็นตัวชี้วัดที่มีประโยชน์ ดังนั้นเราจึงเพิ่มเข้าไป แต่เราจำเป็นต้องกำหนดตัวชี้วัดระดับใหม่สำหรับการคาดการณ์นี้เพราะเรามีตำแหน่ง 3 (เวกเตอร์ของ logits สำหรับแต่ละ BATCH_SIZE * SEQ_LENGTH การคาดการณ์) และ SparseCategoricalAccuracy คาดว่าเพียงอันดับ 2 การคาดการณ์

class FlattenedCategoricalAccuracy(tf.keras.metrics.SparseCategoricalAccuracy):

  def __init__(self, name='accuracy', dtype=tf.float32):
    super().__init__(name, dtype=dtype)

  def update_state(self, y_true, y_pred, sample_weight=None):
    y_true = tf.reshape(y_true, [-1, 1])
    y_pred = tf.reshape(y_pred, [-1, len(vocab), 1])
    return super().update_state(y_true, y_pred, sample_weight)

ตอนนี้เราสามารถรวบรวมรูปแบบและประเมินผลที่เป็นของเรา example_dataset

BATCH_SIZE = 8  # The training and eval batch size for the rest of this tutorial.
keras_model = load_model(batch_size=BATCH_SIZE)
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[FlattenedCategoricalAccuracy()])

# Confirm that loss is much lower on Shakespeare than on random data
loss, accuracy = keras_model.evaluate(example_dataset.take(5), verbose=0)
print(
    'Evaluating on an example Shakespeare character: {a:3f}'.format(a=accuracy))

# As a sanity check, we can construct some completely random data, where we expect
# the accuracy to be essentially random:
random_guessed_accuracy = 1.0 / len(vocab)
print('Expected accuracy for random guessing: {a:.3f}'.format(
    a=random_guessed_accuracy))
random_indexes = np.random.randint(
    low=0, high=len(vocab), size=1 * BATCH_SIZE * (SEQ_LENGTH + 1))
data = collections.OrderedDict(
    snippets=tf.constant(
        ''.join(np.array(vocab)[random_indexes]), shape=[1, 1]))
random_dataset = preprocess(tf.data.Dataset.from_tensor_slices(data))
loss, accuracy = keras_model.evaluate(random_dataset, steps=10, verbose=0)
print('Evaluating on completely random data: {a:.3f}'.format(a=accuracy))
Downloading data from https://storage.googleapis.com/tff-models-public/dickens_rnn.batch8.kerasmodel
16195584/16193984 [==============================] - 0s 0us/step
16203776/16193984 [==============================] - 0s 0us/step
Evaluating on an example Shakespeare character: 0.402000
Expected accuracy for random guessing: 0.012
Evaluating on completely random data: 0.011

ปรับแต่งโมเดลด้วย Federated Learning

TFF ทำให้การคำนวณ TensorFlow เป็นอนุกรมทั้งหมด เพื่อให้สามารถรันในสภาพแวดล้อมที่ไม่ใช่ Python ได้ (แม้ว่าในขณะนี้ มีเพียงรันไทม์การจำลองที่ใช้ใน Python เท่านั้น) ถึงแม้ว่าเราจะทำงานในโหมดกระตือรือร้น (TF 2.0), ขณะนี้ฉิบหาย serializes TensorFlow คำนวณโดยการสร้างปฏิบัติการที่จำเป็นภายในบริบทของการที่ " with tf.Graph.as_default() " คำสั่ง ดังนั้น เราจำเป็นต้องจัดเตรียมฟังก์ชันที่ TFF สามารถใช้เพื่อแนะนำโมเดลของเราในกราฟที่ควบคุมได้ เราทำดังนี้:

# Clone the keras_model inside `create_tff_model()`, which TFF will
# call to produce a new copy of the model inside the graph that it will 
# serialize. Note: we want to construct all the necessary objects we'll need 
# _inside_ this method.
def create_tff_model():
  # TFF uses an `input_spec` so it knows the types and shapes
  # that your model expects.
  input_spec = example_dataset.element_spec
  keras_model_clone = tf.keras.models.clone_model(keras_model)
  return tff.learning.from_keras_model(
      keras_model_clone,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])

ตอนนี้เราพร้อมที่จะสร้างสหพันธ์ Averaging กระบวนการซ้ำซึ่งเราจะใช้ในการปรับปรุงรูปแบบ (สำหรับรายละเอียดเกี่ยวกับขั้นตอนวิธีการ Averaging สหพันธ์ดูกระดาษ การเรียนรู้การสื่อสารที่มีประสิทธิภาพของเครือข่ายการกระจายอำนาจลึกจากข้อมูล )

เราใช้โมเดล Keras ที่คอมไพล์แล้วเพื่อทำการประเมินมาตรฐาน (ไม่ใช่แบบรวมศูนย์) หลังจากการฝึกอบรมแบบรวมศูนย์ในแต่ละรอบ สิ่งนี้มีประโยชน์สำหรับวัตถุประสงค์ในการวิจัยเมื่อทำการเรียนรู้แบบสหพันธรัฐจำลองและมีชุดข้อมูลการทดสอบมาตรฐาน

ในการตั้งค่าการผลิตที่สมจริง เทคนิคเดียวกันนี้อาจใช้เพื่อนำแบบจำลองที่ได้รับการฝึกด้วยการเรียนรู้แบบสหพันธรัฐมาประเมินบนชุดข้อมูลการเปรียบเทียบแบบรวมศูนย์สำหรับวัตถุประสงค์ในการทดสอบหรือการประกันคุณภาพ

# This command builds all the TensorFlow graphs and serializes them: 
fed_avg = tff.learning.build_federated_averaging_process(
    model_fn=create_tff_model,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(lr=0.5))

นี่คือการวนรอบที่ง่ายที่สุด ซึ่งเรารัน federated averaging สำหรับหนึ่งรอบบนไคลเอนต์เดียวในแบทช์เดียว:

state = fed_avg.initialize()
state, metrics = fed_avg.next(state, [example_dataset.take(5)])
train_metrics = metrics['train']
print('loss={l:.3f}, accuracy={a:.3f}'.format(
    l=train_metrics['loss'], a=train_metrics['accuracy']))
loss=4.403, accuracy=0.132

ตอนนี้ เรามาเขียนวงจรการฝึกอบรมและการประเมินที่น่าสนใจกว่านี้เล็กน้อย

เพื่อให้การจำลองนี้ยังคงดำเนินไปอย่างรวดเร็ว เราจึงฝึกกับลูกค้าสามรายเดิมในแต่ละรอบ โดยพิจารณาเพียงสองมินิแบตช์สำหรับแต่ละรอบ

def data(client, source=train_data):
  return preprocess(source.create_tf_dataset_for_client(client)).take(5)


clients = [
    'ALL_S_WELL_THAT_ENDS_WELL_CELIA', 'MUCH_ADO_ABOUT_NOTHING_OTHELLO',
]

train_datasets = [data(client) for client in clients]

# We concatenate the test datasets for evaluation with Keras by creating a 
# Dataset of Datasets, and then identity flat mapping across all the examples.
test_dataset = tf.data.Dataset.from_tensor_slices(
    [data(client, test_data) for client in clients]).flat_map(lambda x: x)

สถานะเริ่มต้นของรูปแบบที่ผลิตโดย fed_avg.initialize() จะขึ้นอยู่กับ initializers สุ่มสำหรับรุ่น Keras ไม่น้ำหนักที่ถูกโหลดตั้งแต่ clone_model() ไม่ได้โคลนน้ำหนัก เพื่อเริ่มการฝึกจากโมเดลที่ฝึกไว้ล่วงหน้า เราตั้งค่าน้ำหนักของโมเดลในสถานะเซิร์ฟเวอร์โดยตรงจากโมเดลที่โหลด

NUM_ROUNDS = 5

# The state of the FL server, containing the model and optimization state.
state = fed_avg.initialize()

# Load our pre-trained Keras model weights into the global model state.
state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in keras_model.non_trainable_weights
    ])


def keras_evaluate(state, round_num):
  # Take our global model weights and push them back into a Keras model to
  # use its standard `.evaluate()` method.
  keras_model = load_model(batch_size=BATCH_SIZE)
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[FlattenedCategoricalAccuracy()])
  state.model.assign_weights_to(keras_model)
  loss, accuracy = keras_model.evaluate(example_dataset, steps=2, verbose=0)
  print('\tEval: loss={l:.3f}, accuracy={a:.3f}'.format(l=loss, a=accuracy))


for round_num in range(NUM_ROUNDS):
  print('Round {r}'.format(r=round_num))
  keras_evaluate(state, round_num)
  state, metrics = fed_avg.next(state, train_datasets)
  train_metrics = metrics['train']
  print('\tTrain: loss={l:.3f}, accuracy={a:.3f}'.format(
      l=train_metrics['loss'], a=train_metrics['accuracy']))

print('Final evaluation')
keras_evaluate(state, NUM_ROUNDS + 1)
Round 0
    Eval: loss=3.324, accuracy=0.401
    Train: loss=4.360, accuracy=0.155
Round 1
    Eval: loss=4.361, accuracy=0.049
    Train: loss=4.235, accuracy=0.164
Round 2
    Eval: loss=4.219, accuracy=0.177
    Train: loss=4.081, accuracy=0.221
Round 3
    Eval: loss=4.080, accuracy=0.174
    Train: loss=3.940, accuracy=0.226
Round 4
    Eval: loss=3.991, accuracy=0.176
    Train: loss=3.840, accuracy=0.226
Final evaluation
    Eval: loss=3.909, accuracy=0.171

ด้วยการเปลี่ยนแปลงเริ่มต้น เรายังทำการฝึกอบรมไม่เพียงพอที่จะสร้างความแตกต่างครั้งใหญ่ แต่ถ้าคุณฝึกฝนข้อมูลของเช็คสเปียร์นานขึ้น คุณควรเห็นความแตกต่างในสไตล์ของข้อความที่สร้างด้วยแบบจำลองที่อัปเดต:

# Set our newly trained weights back in the originally created model.
keras_model_batch1.set_weights([v.numpy() for v in keras_model.weights])
# Text generation requires batch_size=1
print(generate_text(keras_model_batch1, 'What of TensorFlow Federated, you ask? '))
What of TensorFlow Federated, you ask? Shalways, I will call your
compet with any city brought their faces uncompany," besumed him. "When he
sticked Madame Defarge pushed the lamps.

"Have I often but no unison. She had probably come,

นามสกุลที่แนะนำ

กวดวิชานี้เป็นเพียงขั้นตอนแรก! ต่อไปนี้เป็นแนวคิดบางประการสำหรับวิธีที่คุณอาจลองขยายสมุดบันทึกนี้:

  • เขียนวงจรการฝึกอบรมที่สมจริงยิ่งขึ้น โดยคุณจะสุ่มตัวอย่างลูกค้าเพื่อฝึกฝนแบบสุ่ม
  • ใช้ " .repeat(NUM_EPOCHS) " ในชุดข้อมูลลูกค้าที่จะลองหลายยุคสมัยของการฝึกอบรมในท้องถิ่น (เช่นเดียวกับใน McMahan et. al. ) ดูเพิ่มเติม สหพันธ์เรียนรู้สำหรับการจำแนกประเภทของภาพ ที่ไม่นี้
  • เปลี่ยนการ compile() คำสั่งเพื่อการทดสอบด้วยการใช้ขั้นตอนวิธีการเพิ่มประสิทธิภาพที่แตกต่างกันกับลูกค้า
  • ลอง server_optimizer อาร์กิวเมนต์ build_federated_averaging_process จะลองขั้นตอนวิธีการที่แตกต่างกันสำหรับการใช้การปรับปรุงรูปแบบบนเซิร์ฟเวอร์
  • ลอง client_weight_fn โต้แย้งเพื่อ build_federated_averaging_process จะลองน้ำหนักที่แตกต่างกันของลูกค้า การปรับปรุงลูกค้าน้ำหนักเริ่มต้นจากจำนวนตัวอย่างที่ลูกค้า แต่คุณสามารถทำเช่น client_weight_fn=lambda _: tf.constant(1.0)