Giúp bảo vệ Great Barrier Reef với TensorFlow trên Kaggle Tham Challenge

Xây dựng thuật toán học tập liên kết của riêng bạn

Xem trên TensorFlow.org Chạy trong Google Colab Xem nguồn trên GitHub Tải xuống sổ ghi chép

Trước khi chúng ta bắt đầu

Trước khi chúng tôi bắt đầu, vui lòng chạy phần sau để đảm bảo rằng môi trường của bạn được thiết lập chính xác. Nếu bạn không thấy một lời chào, xin vui lòng tham khảo các cài đặt hướng dẫn để được hướng dẫn.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

Trong phân loại hình ảnhvăn bản thế hệ hướng dẫn, chúng tôi đã học làm thế nào để thiết lập mô hình và dữ liệu đường ống cho Federated Learning (FL), và thực hiện đào tạo liên thông qua tff.learning lớp API của TFF.

Đây chỉ là phần nổi của tảng băng khi nói đến nghiên cứu FL. Trong hướng dẫn này, chúng tôi thảo luận làm thế nào để thực hiện thuật toán học liên mà không trì hoãn đến tff.learning API. Chúng tôi mong muốn đạt được những điều sau:

Bàn thắng:

  • Hiểu cấu trúc chung của các thuật toán học liên hợp.
  • Khám phá Federated cốt lõi của TFF.
  • Sử dụng Lõi liên kết để triển khai Trung bình liên kết trực tiếp.

Trong khi hướng dẫn này là khép kín, chúng tôi khuyên đầu tiên đọc phân loại hình ảnhthế hệ văn bản hướng dẫn.

Chuẩn bị dữ liệu đầu vào

Đầu tiên, chúng tôi tải và xử lý trước tập dữ liệu EMNIST có trong TFF. Để biết thêm chi tiết, xem phân loại hình ảnh hướng dẫn.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Để nuôi tập dữ liệu vào mô hình của chúng tôi, chúng tôi làm phẳng dữ liệu, và chuyển đổi mỗi ví dụ thành một tuple của mẫu (flattened_image_vector, label) .

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

Bây giờ chúng tôi chọn một số lượng nhỏ khách hàng và áp dụng xử lý trước ở trên cho tập dữ liệu của họ.

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

Chuẩn bị mô hình

Chúng tôi sử dụng mô hình tương tự như trong các phân loại hình ảnh hướng dẫn. Mô hình này (thực hiện thông qua tf.keras ) có một lớp ẩn duy nhất, theo sau là một lớp softmax.

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

Để sử dụng mô hình này trong TFF, chúng tôi quấn mô hình Keras như một tff.learning.Model . Điều này cho phép chúng tôi thực hiện của mô hình đường chuyền về phía trước trong vòng TFF, và kết quả đầu ra chiết xuất mô hình . Để biết thêm chi tiết, cũng thấy phân loại hình ảnh hướng dẫn.

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Trong khi chúng tôi sử dụng tf.keras để tạo ra một tff.learning.Model , TFF hỗ trợ mô hình tổng quát hơn nhiều. Các mô hình này có các thuộc tính liên quan sau đây để nắm bắt các trọng số của mô hình:

  • trainable_variables : Một iterable của tensors tương ứng với các lớp đào tạo được.
  • non_trainable_variables : Một iterable của tensors tương ứng với các lớp không dễ huấn luyện.

Đối với mục đích của chúng tôi, chúng tôi sẽ chỉ sử dụng trainable_variables . (vì mô hình của chúng tôi chỉ có những cái đó!).

Xây dựng thuật toán học liên kết của riêng bạn

Trong khi tff.learning API cho phép một để tạo ra nhiều biến thể của Federated trung bình, có những thuật toán liên khác mà không phù hợp với gọn gàng vào khuôn khổ này. Ví dụ, bạn có thể muốn thêm quy tắc, cắt, hoặc các thuật toán phức tạp hơn như đào tạo GAN liên . Bạn cũng có thể được thay thể quan tâm trong phân tích liên .

Đối với các thuật toán nâng cao hơn này, chúng tôi sẽ phải viết thuật toán tùy chỉnh của riêng mình bằng cách sử dụng TFF. Trong nhiều trường hợp, các thuật toán liên hợp có 4 thành phần chính:

  1. Bước truyền phát từ máy chủ đến máy khách.
  2. Một bước cập nhật ứng dụng khách cục bộ.
  3. Một bước tải lên từ máy khách đến máy chủ.
  4. Một bước cập nhật máy chủ.

Trong TFF, chúng tôi thường đại diện cho các thuật toán liên như một tff.templates.IterativeProcess (mà chúng tôi gọi là chỉ là một IterativeProcess suốt). Đây là một lớp học có chứa initializenext chức năng. Ở đây, initialize được sử dụng để khởi tạo máy chủ, và next sẽ thực hiện một vòng thông tin liên lạc của thuật toán liên. Hãy viết một bản sơ lược về quy trình lặp đi lặp lại của chúng ta cho FedAvg sẽ trông như thế nào.

Đầu tiên, chúng ta có một chức năng khởi tạo mà chỉ đơn giản tạo ra một tff.learning.Model , và trả về trọng lượng khả năng huấn luyện của mình.

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

Hàm này có vẻ tốt, nhưng như chúng ta sẽ thấy ở phần sau, chúng ta sẽ cần thực hiện một sửa đổi nhỏ để biến nó thành "tính toán TFF".

Chúng tôi cũng muốn phác họa next_fn .

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

Chúng tôi sẽ tập trung vào việc triển khai bốn thành phần này một cách riêng biệt. Đầu tiên chúng tôi tập trung vào các phần có thể được triển khai trong TensorFlow thuần túy, cụ thể là các bước cập nhật máy khách và máy chủ.

TensorFlow Blocks

Cập nhật khách hàng

Chúng tôi sẽ sử dụng chúng tôi tff.learning.Model để làm đào tạo khách hàng trong về cơ bản giống như cách bạn sẽ đào tạo một mô hình TensorFlow. Đặc biệt, chúng tôi sẽ sử dụng tf.GradientTape để tính toán gradient trên lô dữ liệu, sau đó áp dụng những Gradient sử dụng một client_optimizer . Chúng tôi chỉ tập trung vào trọng lượng có thể huấn luyện được.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

Cập nhật máy chủ

Bản cập nhật máy chủ cho FedAvg đơn giản hơn bản cập nhật máy khách. Chúng tôi sẽ thực hiện tính trung bình liên kết "vani", trong đó chúng tôi chỉ cần thay thế trọng số của mô hình máy chủ bằng trung bình của trọng số mô hình khách. Một lần nữa, chúng tôi chỉ tập trung vào mức tạ có thể tập được.

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

Đoạn có thể được đơn giản hóa bằng cách đơn giản trả lại mean_client_weights . Tuy nhiên, việc triển khai tiên tiến hơn của Federated sử dụng trung bình mean_client_weights với các kỹ thuật phức tạp hơn, chẳng hạn như đà hoặc adaptivity.

Thách thức: Thực hiện một phiên bản của server_update đó cập nhật các trọng số máy chủ để là trung điểm của model_weights và mean_client_weights. (Lưu ý: Cách tiếp cận này "trung điểm" tương tự như công việc gần đây trên ưu lookahead !).

Cho đến nay, chúng tôi chỉ viết mã TensorFlow thuần túy. Đây là do thiết kế, vì TFF cho phép bạn sử dụng nhiều mã TensorFlow mà bạn đã quen thuộc. Tuy nhiên, bây giờ chúng ta phải xác định logic dàn nhạc, có nghĩa là, logic rằng mệnh lệnh gì các chương trình phát sóng máy chủ cho khách hàng, và những gì các cập nhật khách đến máy chủ.

Điều này đòi hỏi các Federated cốt lõi của TFF.

Giới thiệu về Liên kết Core

Các Federated Core (FC) là một tập hợp các giao diện cấp thấp phục vụ như là nền tảng cho các tff.learning API. Tuy nhiên, những giao diện này không giới hạn trong việc học. Trên thực tế, chúng có thể được sử dụng để phân tích và nhiều phép tính khác trên dữ liệu phân tán.

Ở cấp độ cao, lõi liên hợp là một môi trường phát triển cho phép logic chương trình được diễn đạt gọn nhẹ để kết hợp mã TensorFlow với các toán tử truyền thông phân tán (chẳng hạn như tổng phân phối và chương trình phát sóng). Mục đích là cung cấp cho các nhà nghiên cứu và người thực hành quyền kiểm soát đối với giao tiếp phân tán trong hệ thống của họ mà không yêu cầu chi tiết triển khai hệ thống (chẳng hạn như chỉ định trao đổi thông điệp mạng điểm-điểm).

Một điểm chính là TFF được thiết kế để bảo vệ quyền riêng tư. Do đó, nó cho phép kiểm soát rõ ràng nơi dữ liệu cư trú, để ngăn chặn sự tích tụ dữ liệu không mong muốn tại vị trí máy chủ tập trung.

Dữ liệu liên kết

Một khái niệm chính trong TFF là "dữ liệu liên kết", dùng để chỉ tập hợp các mục dữ liệu được lưu trữ trên một nhóm thiết bị trong hệ thống phân tán (ví dụ: bộ dữ liệu máy khách hoặc trọng số mô hình máy chủ). Chúng tôi mô hình toàn bộ bộ sưu tập của các mục dữ liệu trên tất cả các thiết bị như một giá trị liên duy nhất.

Ví dụ: giả sử chúng ta có các thiết bị khách mà mỗi thiết bị có một phao biểu thị nhiệt độ của cảm biến. Chúng ta có thể đại diện cho nó như là một phao liên bởi

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

Loại Federated được quy định bởi một loại T của các thành viên của nó (ví dụ. tf.float32 ) và một nhóm G của các thiết bị. Chúng tôi sẽ tập trung vào các trường hợp G là một trong hai tff.CLIENTS hoặc tff.SERVER . Một loại liên đó được biểu diễn dưới dạng {T}@G , như hình dưới đây.

str(federated_float_on_clients)
'{float32}@CLIENTS'

Tại sao chúng ta lại quan tâm nhiều đến các vị trí? Mục tiêu chính của TFF là cho phép viết mã có thể được triển khai trên một hệ thống phân tán thực. Điều này có nghĩa là điều quan trọng là phải suy luận về việc tập hợp con của thiết bị nào thực thi mã nào và vị trí của các phần dữ liệu khác nhau.

TFF tập trung vào ba điều: dữ liệu, nơi dữ liệu được đặt, và làm thế nào dữ liệu đang được chuyển đổi. Hai đầu tiên được đóng gói trong các loại liên kết, trong khi người cuối cùng được đóng gói trong tính toán liên.

Tính toán liên hợp

TFF là một môi trường lập trình chức năng mạnh mẽ, đánh máy mà đơn vị cơ bản là tính toán liên. Đây là những phần logic chấp nhận các giá trị được liên kết làm đầu vào và trả về các giá trị được liên kết dưới dạng đầu ra.

Ví dụ: giả sử chúng tôi muốn tính trung bình nhiệt độ trên các cảm biến khách hàng của chúng tôi. Chúng tôi có thể xác định những điều sau (sử dụng float liên kết của chúng tôi):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

Bạn có thể hỏi, làm thế nào là khác nhau này từ tf.function trang trí trong TensorFlow? Câu trả lời chính là các mã được tạo bởi tff.federated_computation không phải là TensorFlow hay Python mã; Đó là một đặc điểm kỹ thuật của một hệ thống phân phối trong một nền tảng độc lập ngôn ngữ keo nội bộ.

Mặc dù điều này nghe có vẻ phức tạp, nhưng bạn có thể coi các phép tính TFF như là các hàm với các chữ ký kiểu được xác định rõ ràng. Những chữ ký kiểu này có thể được truy vấn trực tiếp.

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

Đây tff.federated_computation chấp nhận các tham số kiểu liên {float32}@CLIENTS , và trả về giá trị kiểu liên {float32}@SERVER . Các phép tính liên kết cũng có thể đi từ máy chủ đến máy khách, từ máy khách đến máy khách hoặc từ máy chủ đến máy chủ. Các phép tính liên hợp cũng có thể được cấu tạo giống như các hàm bình thường, miễn là các chữ ký kiểu của chúng khớp với nhau.

Để hỗ trợ phát triển, TFF cho phép bạn gọi một tff.federated_computation như một hàm Python. Ví dụ, chúng ta có thể gọi

get_average_temperature([68.5, 70.3, 69.8])
69.53334

Tính toán không háo hức và TensorFlow

Có hai hạn chế chính cần lưu ý. Đầu tiên, khi trình thông dịch Python gặp một tff.federated_computation trang trí, chức năng được bắt nguồn từ một lần và đăng để sử dụng trong tương lai. Do tính chất phi tập trung của Học liên kết, việc sử dụng trong tương lai này có thể xảy ra ở những nơi khác, chẳng hạn như môi trường thực thi từ xa. Do đó, TFF tính toán về cơ bản không háo hức. Hành vi này có phần tương tự như của các tf.function trang trí trong TensorFlow.

Thứ hai, một tính toán liên chỉ có thể bao gồm các nhà khai thác liên (như tff.federated_mean ), họ không thể chứa các hoạt động TensorFlow. Đang TensorFlow phải được giới hạn trong các khối trang trí với tff.tf_computation . Hầu hết các mã TensorFlow bình thường có thể được trang trí trực tiếp, chẳng hạn như chức năng sau đây mà phải mất một số và thêm 0.5 đến nó.

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

Đây cũng có chữ ký kiểu, nhưng không có vị trí. Ví dụ, chúng ta có thể gọi

str(add_half.type_signature)
'(float32 -> float32)'

Ở đây chúng ta thấy một sự khác biệt quan trọng giữa tff.federated_computationtff.tf_computation . Cái trước có vị trí rõ ràng, trong khi cái sau thì không.

Chúng ta có thể sử dụng tff.tf_computation khối trong tính toán liên bằng cách xác định vị trí. Hãy tạo một hàm bổ sung một nửa, nhưng chỉ đối với các float được liên kết tại các máy khách. Chúng ta có thể làm điều này bằng cách sử dụng tff.federated_map , áp dụng một định tff.tf_computation , trong khi vẫn giữ vị trí đó.

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

Chức năng này gần giống như add_half , ngoại trừ việc nó chỉ chấp nhận các giá trị với vị trí tại tff.CLIENTS , và trả về giá trị với cùng một vị trí. Chúng ta có thể thấy điều này trong chữ ký kiểu của nó:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

Tóm tắt:

  • TFF hoạt động dựa trên các giá trị liên kết.
  • Mỗi giá trị liên có một kiểu liên kết, với một kiểu (ví dụ. tf.float32 ) và một vị trí (ví dụ. tff.CLIENTS ).
  • Giá trị Federated có thể được chuyển sử dụng tính liên kết, mà phải được trang trí với tff.federated_computation và một loại chữ ký liên.
  • Đang TensorFlow phải được chứa trong các khối với tff.tf_computation trang trí.
  • Các khối này sau đó có thể được kết hợp vào các phép tính liên hợp.

Xây dựng thuật toán Học liên kết của riêng bạn, đã xem lại

Bây giờ chúng ta đã có cái nhìn sơ lược về Lõi liên kết, chúng ta có thể xây dựng thuật toán học liên kết của riêng mình. Hãy nhớ rằng trên, chúng ta định nghĩa một initialize_fnnext_fn cho thuật toán của chúng tôi. Các next_fn sẽ tận dụng các client_updateserver_update chúng ta định nghĩa sử dụng mã TensorFlow tinh khiết.

Tuy nhiên, để thực hiện thuật toán của chúng tôi một tính liên kết, chúng tôi sẽ cần cả next_fninitialize_fn từng là một tff.federated_computation .

Khối liên kết TensorFlow

Tạo tính toán khởi tạo

Chức năng khởi tạo sẽ được khá đơn giản: Chúng tôi sẽ tạo ra một mô hình sử dụng model_fn . Tuy nhiên, hãy nhớ rằng chúng ta phải tách ra mã TensorFlow của chúng tôi sử dụng tff.tf_computation .

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

Sau đó chúng tôi có thể chuyển thông tin này trực tiếp vào một tính toán liên sử dụng tff.federated_value .

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

Tạo next_fn

Bây giờ chúng tôi sử dụng mã cập nhật máy khách và máy chủ của mình để viết thuật toán thực tế. Đầu tiên chúng ta sẽ lần lượt của chúng tôi client_update thành một tff.tf_computation chấp nhận một tập hợp dữ liệu khách hàng và trọng lượng máy chủ, và kết quả đầu ra một khối lượng khách hàng được cập nhật tensor.

Chúng ta sẽ cần những loại tương ứng để trang trí đúng chức năng của chúng ta. May mắn thay, loại trọng số máy chủ có thể được trích xuất trực tiếp từ mô hình của chúng tôi.

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

Hãy xem xét chữ ký loại tập dữ liệu. Hãy nhớ rằng chúng tôi đã chụp 28 x 28 hình ảnh (với nhãn số nguyên) và làm phẳng chúng.

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

Chúng tôi cũng có thể trích xuất các loại trọng số mô hình bằng cách sử dụng của chúng tôi server_init chức năng trên.

model_weights_type = server_init.type_signature.result

Kiểm tra chữ ký kiểu, chúng ta sẽ có thể thấy kiến ​​trúc của mô hình của chúng ta!

str(model_weights_type)
'<float32[784,10],float32[10]>'

Bây giờ chúng ta có thể tạo chúng tôi tff.tf_computation cho bản cập nhật client.

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

Các tff.tf_computation phiên bản cập nhật máy chủ có thể được định nghĩa theo cách tương tự, sử dụng các loại chúng tôi đã trích xuất.

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

Cuối cùng nhưng không kém, chúng ta cần phải tạo ra các tff.federated_computation đó sẽ đem này tất cả cùng nhau. Chức năng này sẽ chấp nhận hai giá trị liên kết, một tương ứng với trọng lượng máy chủ (với vị trí tff.SERVER ), và người kia tương ứng với bộ dữ liệu khách hàng (với vị trí tff.CLIENTS ).

Lưu ý rằng cả hai loại này đã được định nghĩa ở trên! Chúng tôi chỉ đơn giản là cần phải cung cấp cho họ những vị trí thích hợp sử dụng tff.FederatedType .

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

Hãy nhớ 4 yếu tố của một thuật toán FL?

  1. Bước truyền phát từ máy chủ đến máy khách.
  2. Một bước cập nhật ứng dụng khách cục bộ.
  3. Một bước tải lên từ máy khách đến máy chủ.
  4. Một bước cập nhật máy chủ.

Bây giờ chúng ta đã xây dựng phần trên, mỗi phần có thể được trình bày nhỏ gọn dưới dạng một dòng mã TFF. Sự đơn giản này là lý do tại sao chúng tôi phải cẩn thận hơn để chỉ định những thứ chẳng hạn như các loại liên kết!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

Bây giờ chúng ta có một tff.federated_computation cho cả việc khởi tạo thuật toán, và cho chạy một bước của thuật toán. Để kết thúc thuật toán của chúng tôi, chúng tôi vượt qua những thành tff.templates.IterativeProcess .

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Hãy nhìn vào kiểu chữ ký của người initializenext chức năng của quá trình lặp đi lặp lại của chúng tôi.

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

Điều này phản ánh một thực tế rằng federated_algorithm.initialize là một hàm không-arg mà trả về một mô hình đơn lớp (với một ma trận trọng lượng 784-by-10, và 10 đơn vị thiên vị).

str(federated_algorithm.next.type_signature)
'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

Ở đây, chúng ta thấy rằng federated_algorithm.next chấp nhận một mô hình máy chủ và dữ liệu khách hàng, và trả về một mô hình máy chủ cập nhật.

Đánh giá thuật toán

Hãy chạy một vài vòng, và xem sự mất mát thay đổi như thế nào. Thứ nhất, chúng tôi sẽ xác định một chức năng đánh giá sử dụng phương pháp tập trung thảo luận trong hướng dẫn thứ hai.

Đầu tiên, chúng tôi tạo một tập dữ liệu đánh giá tập trung và sau đó áp dụng cùng một quy trình xử lý trước mà chúng tôi đã sử dụng cho dữ liệu đào tạo.

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

Tiếp theo, chúng tôi viết một hàm chấp nhận trạng thái máy chủ và sử dụng Keras để đánh giá trên tập dữ liệu thử nghiệm. Nếu bạn đã quen thuộc với tf.Keras , điều này sẽ tất cả cái nhìn quen thuộc, mặc dù lưu ý việc sử dụng set_weights !

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

Bây giờ, hãy khởi tạo thuật toán của chúng tôi và đánh giá trên tập thử nghiệm.

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 2s 767us/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

Hãy tập luyện vài hiệp và xem có gì thay đổi không nhé.

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 2s 738us/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

Chúng tôi thấy hàm mất mát giảm nhẹ. Mặc dù bước nhảy nhỏ, nhưng chúng tôi chỉ thực hiện 15 vòng huấn luyện và trên một nhóm nhỏ khách hàng. Để xem kết quả tốt hơn, chúng ta có thể phải thực hiện hàng trăm, nếu không muốn nói là hàng nghìn vòng.

Sửa đổi thuật toán của chúng tôi

Tại thời điểm này, chúng ta hãy dừng lại và nghĩ về những gì chúng ta đã đạt được. Chúng tôi đã triển khai Trung bình liên kết trực tiếp bằng cách kết hợp mã TensorFlow thuần túy (cho các bản cập nhật máy khách và máy chủ) với các tính toán được liên kết từ Lõi liên kết của TFF.

Để thực hiện việc học ngụy biện hơn, chúng ta có thể chỉ cần thay đổi những gì chúng ta có ở trên. Đặc biệt, bằng cách chỉnh sửa mã TF thuần túy ở trên, chúng tôi có thể thay đổi cách máy khách thực hiện đào tạo hoặc cách máy chủ cập nhật mô hình của nó.

Thách thức: Thêm clipping dốc đến client_update chức năng.

Nếu chúng tôi muốn thực hiện các thay đổi lớn hơn, chúng tôi cũng có thể có máy chủ lưu trữ và phát nhiều dữ liệu hơn. Ví dụ: máy chủ cũng có thể lưu trữ tốc độ học tập của khách hàng và làm cho nó giảm dần theo thời gian! Lưu ý rằng điều này sẽ đòi hỏi thay đổi đối với chữ ký loại được sử dụng trong tff.tf_computation gọi trên.

Harder Thách thức: Thực hiện Federated trung bình với học phân rã tốc độ trên các máy khách.

Tại thời điểm này, bạn có thể bắt đầu nhận ra mức độ linh hoạt trong những gì bạn có thể triển khai trong khuôn khổ này. Đối với những ý tưởng (kể cả câu trả lời cho những thách thức khó khăn hơn ở trên) bạn sẽ nhìn thấy mã nguồn cho tff.learning.build_federated_averaging_process , hoặc kiểm tra khác nhau đề tài nghiên cứu sử dụng TFF.