Costruire il proprio algoritmo di apprendimento federato

Visualizza su TensorFlow.org Esegui in Google Colab Visualizza la fonte su GitHub Scarica taccuino

Prima di iniziare

Prima di iniziare, eseguire quanto segue per assicurarsi che l'ambiente sia configurato correttamente. Se non vedi un saluto, si prega di fare riferimento alla installazione guida per le istruzioni.

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

Nei classificazione di immagini e testi di generazione di tutorial, abbiamo imparato come impostare modelli e dati condutture per Federated Learning (FL), ed eseguito la formazione federata tramite il tff.learning livello API di TFF.

Questa è solo la punta dell'iceberg quando si tratta di ricerca FL. In questo tutorial, si discute come implementare algoritmi di apprendimento federati, senza rinviare al tff.learning API. Miriamo a realizzare quanto segue:

Obiettivi:

  • Comprendere la struttura generale degli algoritmi di apprendimento federato.
  • Esplora la Federated Nucleo di TFF.
  • Usa Federated Core per implementare direttamente la media federata.

Anche se questo tutorial è a sé stante, si consiglia prima lettura la classificazione di immagini e di generazione del testo tutorial.

Preparazione dei dati di input

Per prima cosa carichiamo e preprocessiamo il set di dati EMNIST incluso in TFF. Per maggiori dettagli, vedere la classificazione di immagini tutorial.

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Per alimentare il set di dati nel nostro modello, abbiamo appiattire i dati, e convertire ogni esempio in una tupla nella forma (flattened_image_vector, label) .

NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)

Ora selezioniamo un piccolo numero di client e applichiamo la pre-elaborazione sopra ai loro set di dati.

client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

Preparazione del modello

Usiamo lo stesso modello come nella classificazione delle immagini tutorial. Questo modello (implementato attraverso tf.keras ) ha un solo strato nascosto, seguito da uno strato softmax.

def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

Per poter utilizzare questo modello in TFF, ci avvolgiamo il modello Keras come tff.learning.Model . Questo ci permette di eseguire il modello in avanti all'interno di TFF, e output del modello estratto . Per maggiori dettagli, si veda anche la classificazione delle immagini tutorial.

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

Mentre abbiamo usato tf.keras di creare un tff.learning.Model , TFF supporta modelli molto più generali. Questi modelli hanno i seguenti attributi rilevanti che catturano i pesi del modello:

  • trainable_variables : Un iterable dei tensori corrispondenti a strati addestrabili.
  • non_trainable_variables : Un iterable dei tensori corrispondenti a strati non addestrabili.

Per i nostri scopi, useremo soltanto i trainable_variables . (come il nostro modello ha solo quelli!).

Costruire il proprio algoritmo di apprendimento federato

Mentre il tff.learning API permette di creare molte varianti di federati della media, ci sono altri algoritmi federati che non rientrano esattamente in questo quadro. Ad esempio, è possibile aggiungere regolarizzazione, ritaglio, o algoritmi più complessi come ad esempio la formazione GAN federata . Si può anche essere invece essere interessati a analisi federati .

Per questi algoritmi più avanzati, dovremo scrivere il nostro algoritmo personalizzato utilizzando TFF. In molti casi, gli algoritmi federati hanno 4 componenti principali:

  1. Un passaggio di trasmissione da server a client.
  2. Un passaggio di aggiornamento del client locale.
  3. Un passaggio di caricamento da client a server.
  4. Un passaggio di aggiornamento del server.

In TFF, generalmente rappresentiamo algoritmi federati come tff.templates.IterativeProcess (che noi definiamo solo un IterativeProcess in tutto). Questa è una classe che contiene initialize e next funzioni. Qui, initialize viene utilizzato per inizializzare il server, e next si esibirà un round comunicazione dell'algoritmo federata. Scriviamo uno scheletro di come dovrebbe essere il nostro processo iterativo per FedAvg.

In primo luogo, abbiamo una funzione di inizializzazione che crea semplicemente un tff.learning.Model , e restituisce i suoi pesi addestrabili.

def initialize_fn():
  model = model_fn()
  return model.trainable_variables

Questa funzione sembra buona, ma come vedremo in seguito, sarà necessario apportare una piccola modifica per renderla un "calcolo TFF".

Vogliamo anche di delineare il next_fn .

def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights

Ci concentreremo sull'implementazione di questi quattro componenti separatamente. Per prima cosa ci concentriamo sulle parti che possono essere implementate in puro TensorFlow, vale a dire le fasi di aggiornamento di client e server.

Blocchi TensorFlow

Aggiornamento del cliente

Useremo il nostro tff.learning.Model per fare formazione cliente essenzialmente nello stesso modo in cui si addestrare un modello tensorflow. In particolare, useremo tf.GradientTape per calcolare il gradiente su lotti di dati, quindi applicare questi gradiente utilizzando un client_optimizer . Ci concentriamo solo sui pesi allenabili.

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

Aggiornamento del server

L'aggiornamento del server per FedAvg è più semplice dell'aggiornamento del client. Implementeremo la media federata "vanilla", in cui sostituiremo semplicemente i pesi del modello server con la media dei pesi del modello client. Ancora una volta, ci concentriamo solo sui pesi allenabili.

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

Il frammento potrebbe essere semplificata semplicemente restituendo i mean_client_weights . Tuttavia, le implementazioni più avanzate di Federated uso della media mean_client_weights con tecniche più sofisticate, come la quantità di moto o di adattività.

Sfida: implementare una versione di server_update che aggiorna i pesi dei server per essere il punto medio di model_weights e mean_client_weights. (Nota: Questo tipo di approccio "punto centrale" è analogo al recente lavoro sulla ottimizzatore Lookahead !).

Finora abbiamo scritto solo codice TensorFlow puro. Questo è dovuto alla progettazione, poiché TFF ti consente di utilizzare gran parte del codice TensorFlow con cui hai già familiarità. Tuttavia, ora dobbiamo specificare la logica di orchestrazione, che è, la logica che determina ciò che le trasmissioni server al client, e ciò che gli upload dei client al server.

Ciò richiederà la Federati Nucleo di TFF.

Introduzione al Nucleo Federato

Federati nucleo (FC) è un insieme di interfacce a basso livello che servono come base per la tff.learning API. Tuttavia, queste interfacce non si limitano all'apprendimento. In effetti, possono essere utilizzati per l'analisi e molti altri calcoli su dati distribuiti.

Ad alto livello, il core federato è un ambiente di sviluppo che consente alla logica del programma espressa in modo compatto di combinare il codice TensorFlow con operatori di comunicazione distribuiti (come somme distribuite e trasmissioni). L'obiettivo è fornire a ricercatori e professionisti un controllo esplicito sulla comunicazione distribuita nei loro sistemi, senza richiedere dettagli di implementazione del sistema (come specificare scambi di messaggi di rete punto-punto).

Un punto chiave è che TFF è progettato per la tutela della privacy. Pertanto, consente un controllo esplicito su dove risiedono i dati, per prevenire l'accumulo indesiderato di dati nella posizione del server centralizzato.

Dati federati

Un concetto chiave in TFF è "dati federati", che si riferisce a una raccolta di elementi di dati ospitati su un gruppo di dispositivi in ​​un sistema distribuito (ad es. set di dati client o pesi del modello del server). Modelliamo l'intera collezione di elementi di dati su tutti i dispositivi come un singolo valore federata.

Ad esempio, supponiamo di avere dispositivi client che hanno ciascuno un galleggiante che rappresenta la temperatura di un sensore. Possiamo rappresentare come un galleggiante federata

federated_float_on_clients = tff.FederatedType(tf.float32, tff.CLIENTS)

Tipi federate sono specificati da un tipo T dei suoi costituenti utente (es. tf.float32 ) e un gruppo G di dispositivi. Ci concentreremo sui casi in cui G è o tff.CLIENTS o tff.SERVER . Un tale tipo federata è rappresentato come {T}@G , come illustrato di seguito.

str(federated_float_on_clients)
'{float32}@CLIENTS'

Perché ci preoccupiamo così tanto dei posizionamenti? Un obiettivo chiave di TFF è consentire la scrittura di codice che potrebbe essere distribuito su un vero sistema distribuito. Ciò significa che è fondamentale ragionare su quali sottoinsiemi di dispositivi eseguono quale codice e dove risiedono i diversi dati.

TFF si concentra su tre cose: i dati, dove si trova il dato, e in che modo i dati si sta trasformando. I primi due sono incapsulati nei tipi federate, mentre l'ultimo è incapsulato nei calcoli federati.

Calcoli federati

TFF è un ambiente di programmazione funzionale fortemente tipizzato le cui unità di base sono calcoli federati. Si tratta di parti di logica che accettano valori federati come input e restituiscono valori federati come output.

Ad esempio, supponiamo di voler fare la media delle temperature sui sensori dei nostri clienti. Potremmo definire quanto segue (usando il nostro float federato):

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  return tff.federated_mean(client_temperatures)

Si potrebbe chiedere, come è questo diverso dal tf.function decoratore in tensorflow? La risposta fondamentale è che il codice generato da tff.federated_computation è né codice tensorflow né Python; Si tratta di una specificazione di un sistema distribuito in una lingua colla indipendente dalla piattaforma interna.

Anche se questo può sembrare complicato, puoi pensare ai calcoli TFF come funzioni con firme di tipo ben definite. Queste firme di tipo possono essere interrogate direttamente.

str(get_average_temperature.type_signature)
'({float32}@CLIENTS -> float32@SERVER)'

Questo tff.federated_computation accetta argomenti di tipo federato {float32}@CLIENTS ei valori ritorni di tipo federato {float32}@SERVER . I calcoli federati possono anche andare da server a client, da client a client o da server a server. I calcoli federati possono anche essere composti come normali funzioni, purché le loro firme di tipo corrispondano.

Per sostenere lo sviluppo, TFF permette di richiamare un tff.federated_computation come una funzione Python. Ad esempio, possiamo chiamare

get_average_temperature([68.5, 70.3, 69.8])
69.53334

Calcoli non impegnativi e TensorFlow

Ci sono due restrizioni chiave di cui essere a conoscenza. In primo luogo, quando l'interprete Python incontra un tff.federated_computation decoratore, la funzione viene tracciata una volta serializzato per un utilizzo futuro. A causa della natura decentralizzata dell'apprendimento federato, questo utilizzo futuro potrebbe verificarsi altrove, ad esempio in un ambiente di esecuzione remota. Pertanto, i calcoli TFF sono fondamentalmente non ansioso. Questo comportamento è analogo a quello del tf.function decoratore nel tensorflow.

In secondo luogo, un calcolo federata essere costituito solo da operatori federate (come tff.federated_mean ), essi non possono contenere operazioni tensorflow. Codice tensorflow deve limitarsi a blocchi decorati con tff.tf_computation . La maggior parte del codice tensorflow ordinaria può essere decorato direttamente, come ad esempio la seguente funzione che prende un numero e aggiunge 0.5 ad esso.

@tff.tf_computation(tf.float32)
def add_half(x):
  return tf.add(x, 0.5)

Questi hanno anche di tipo firme, ma senza posizionamenti. Ad esempio, possiamo chiamare

str(add_half.type_signature)
'(float32 -> float32)'

Qui vediamo una differenza importante tra tff.federated_computation e tff.tf_computation . Il primo ha posizionamenti espliciti, mentre il secondo no.

Possiamo usare tff.tf_computation blocchi nei calcoli federate specificando posizionamenti. Creiamo una funzione che aggiunga la metà, ma solo ai float federati presso i client. Possiamo farlo utilizzando tff.federated_map , che applica un dato tff.tf_computation , pur mantenendo il posizionamento.

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  return tff.federated_map(add_half, x)

Questa funzione è quasi identica a add_half , eccetto che accetta solo valori con posizionamento a tff.CLIENTS e valori ritorni con lo stesso posizionamento. Possiamo vederlo nella sua firma del tipo:

str(add_half_on_clients.type_signature)
'({float32}@CLIENTS -> {float32}@CLIENTS)'

In sintesi:

  • TFF opera su valori federati.
  • Ogni valore federata ha un tipo federata, con un tipo (ad es. tf.float32 ) ed un posizionamento (es. tff.CLIENTS ).
  • Valori federati possono essere trasformate mediante calcoli federati, che devono essere decorate con tff.federated_computation e una firma tipo federata.
  • Codice tensorflow deve essere contenuto in blocchi con tff.tf_computation decoratori.
  • Questi blocchi possono quindi essere incorporati in calcoli federati.

Costruire il proprio algoritmo di apprendimento federato, rivisitato

Ora che abbiamo dato un'occhiata al Federated Core, possiamo costruire il nostro algoritmo di apprendimento federato. Ricordate che in precedenza, abbiamo definito un initialize_fn e next_fn per il nostro algoritmo. Il next_fn farà uso del client_update e server_update abbiamo definito utilizzando il codice tensorflow puro.

Tuttavia, al fine di rendere il nostro algoritmo di un calcolo federata, avremo bisogno sia la next_fn e initialize_fn ad ogni essere un tff.federated_computation .

Blocchi federati TensorFlow

Creazione del calcolo di inizializzazione

La funzione di inizializzazione sarà molto semplice: Creeremo un modello utilizzando model_fn . Tuttavia, ricordiamo che dobbiamo separare il nostro codice tensorflow utilizzando tff.tf_computation .

@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

Possiamo quindi passare questo direttamente in un calcolo federata utilizzando tff.federated_value .

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

Creazione del next_fn

Ora usiamo il nostro codice di aggiornamento client e server per scrivere l'algoritmo vero e proprio. Per prima cosa rivolgere la nostra client_update in un tff.tf_computation che accetta un set di dati client e pesi dei server, ed emette una versione aggiornata pesi client tensore.

Avremo bisogno dei tipi corrispondenti per decorare adeguatamente la nostra funzione. Fortunatamente, il tipo dei pesi del server può essere estratto direttamente dal nostro modello.

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

Diamo un'occhiata alla firma del tipo di set di dati. Ricorda che abbiamo preso 28 per 28 immagini (con etichette intere) e le abbiamo appiattite.

str(tf_dataset_type)
'<float32[?,784],int32[?,1]>*'

Possiamo anche estrarre il tipo di modello pesi utilizzando la nostra server_init funzione di cui sopra.

model_weights_type = server_init.type_signature.result

Esaminando la firma del tipo, saremo in grado di vedere l'architettura del nostro modello!

str(model_weights_type)
'<float32[784,10],float32[10]>'

Ora possiamo creare la nostra tff.tf_computation per l'aggiornamento del client.

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

La tff.tf_computation versione dell'aggiornamento del server può essere definito in modo simile, utilizzando i tipi abbiamo già estratto.

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

Ultimo, ma non meno importante, abbiamo bisogno di creare il tff.federated_computation che mette insieme tutto questo. Questa funzione accetta due valori federati, uno corrispondenti ai pesi dei server (con il posizionamento tff.SERVER ), e l'altro corrispondente alle serie di dati del cliente (con il posizionamento tff.CLIENTS ).

Nota che entrambi questi tipi sono stati definiti sopra! Abbiamo semplicemente bisogno di dare loro il corretto posizionamento utilizzando tff.FederatedType .

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

Ricordi i 4 elementi di un algoritmo FL?

  1. Un passaggio di trasmissione da server a client.
  2. Un passaggio di aggiornamento del client locale.
  3. Un passaggio di caricamento da client a server.
  4. Un passaggio di aggiornamento del server.

Ora che abbiamo costruito quanto sopra, ogni parte può essere rappresentata in modo compatto come una singola riga di codice TFF. Questa semplicità è il motivo per cui abbiamo dovuto fare molta attenzione a specificare cose come i tipi federati!

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

Ora abbiamo una tff.federated_computation sia per l'inizializzazione algoritmo, e per l'esecuzione di un passo dell'algoritmo. Per finire il nostro algoritmo, passiamo questi in tff.templates.IterativeProcess .

federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

Diamo un'occhiata alla firma tipo di initialize e next funzioni del nostro processo iterativo.

str(federated_algorithm.initialize.type_signature)
'( -> <float32[784,10],float32[10]>@SERVER)'

Questo riflette il fatto che federated_algorithm.initialize è una funzione non-arg che restituisce un modello monostrato (con una matrice di pesi 784-by-10, e 10 unità di bias).

str(federated_algorithm.next.type_signature)
'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

Qui, vediamo che federated_algorithm.next accetta un modello di server e dati dei clienti, e restituisce un modello di server aggiornato.

Valutare l'algoritmo

Facciamo qualche giro e vediamo come cambia la perdita. In primo luogo, si definirà una funzione di valutazione usando l'approccio centralizzato discusso nel secondo tutorial.

Per prima cosa creiamo un set di dati di valutazione centralizzato, quindi applichiamo la stessa preelaborazione utilizzata per i dati di addestramento.

central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

Successivamente, scriviamo una funzione che accetta uno stato del server e utilizza Keras per valutare il set di dati di test. Se si ha familiarità con tf.Keras , questo sarà tutto aspetto familiare, anche se nota l'uso di set_weights !

def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)

Ora inizializziamo il nostro algoritmo e valutiamo sul set di test.

server_state = federated_algorithm.initialize()
evaluate(server_state)
2042/2042 [==============================] - 2s 767us/step - loss: 2.8479 - sparse_categorical_accuracy: 0.1027

Alleniamoci per qualche round e vediamo se cambia qualcosa.

for round in range(15):
  server_state = federated_algorithm.next(server_state, federated_train_data)
evaluate(server_state)
2042/2042 [==============================] - 2s 738us/step - loss: 2.5867 - sparse_categorical_accuracy: 0.0980

Vediamo una leggera diminuzione della funzione di perdita. Anche se il salto è piccolo, abbiamo eseguito solo 15 cicli di allenamento e su un piccolo sottoinsieme di client. Per vedere risultati migliori, potremmo dover fare centinaia se non migliaia di round.

Modificare il nostro algoritmo

A questo punto, fermiamoci a pensare a ciò che abbiamo realizzato. Abbiamo implementato la media federata direttamente combinando puro codice TensorFlow (per gli aggiornamenti client e server) con calcoli federati dal Federated Core di TFF.

Per eseguire un apprendimento più sofisticato, possiamo semplicemente modificare ciò che abbiamo sopra. In particolare, modificando il puro codice TF sopra, possiamo cambiare il modo in cui il client esegue l'addestramento o il modo in cui il server aggiorna il suo modello.

Sfida: Aggiungi clipping gradiente al client_update funzioni.

Se volessimo apportare modifiche più grandi, potremmo anche archiviare il server e trasmettere più dati. Ad esempio, il server potrebbe anche memorizzare il tasso di apprendimento del client e farlo decadere nel tempo! Nota che ciò richiederà modifiche alle firme tipi utilizzati nelle tff.tf_computation chiama sopra.

Harder Sfida: Implementare Federati della media con l'apprendimento decadimento tasso sui client.

A questo punto, potresti iniziare a renderti conto di quanta flessibilità ci sia in ciò che puoi implementare in questo framework. Per le idee (tra cui la risposta alla sfida più difficile di cui sopra) si può vedere il codice sorgente per tff.learning.build_federated_averaging_process , o controllare i vari progetti di ricerca utilizzando TFF.