Addestra un modello utilizzando un web worker

In questo tutorial esplorerai un'applicazione Web di esempio che utilizza un lavoratore Web per addestrare una rete neurale ricorrente (RNN) a eseguire l'addizione di numeri interi. L'app di esempio non definisce in modo esplicito l'operatore di addizione. Invece, addestra l'RNN utilizzando somme di esempio.

Naturalmente, questo non è il modo più efficiente per sommare due numeri interi! Ma il tutorial illustra una tecnica importante nel Web ML: come eseguire calcoli a esecuzione prolungata senza bloccare il thread principale, che gestisce la logica dell'interfaccia utente.

L'applicazione di esempio per questo tutorial è disponibile online , quindi non è necessario scaricare alcun codice o configurare un ambiente di sviluppo. Se desideri eseguire il codice localmente, completa i passaggi facoltativi in ​​Eseguire l'esempio localmente . Se non desideri configurare un ambiente di sviluppo, puoi passare a Esplora l'esempio .

Il codice di esempio è disponibile su GitHub .

(Facoltativo) Esegui l'esempio localmente

Prerequisiti

Per eseguire l'app di esempio localmente, è necessario che sia installato quanto segue nel tuo ambiente di sviluppo:

Installa ed esegui l'app di esempio

  1. Clona o scarica il repository tfjs-examples .
  2. Passare alla directory addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Installa le dipendenze:

    yarn
    
  4. Avviare il server di sviluppo:

    yarn run watch
    

Esplora l'esempio

Apri l'app di esempio . (Oppure, se stai eseguendo l'esempio localmente, vai a http://localhost:1234 nel tuo browser.)

Dovresti vedere una pagina intitolata TensorFlow.js: Addition RNN . Segui le istruzioni per provare l'app.

Utilizzando il modulo Web, è possibile aggiornare alcuni dei parametri utilizzati per addestrare il modello, inclusi i seguenti:

  • Cifre : il numero massimo di cifre nei termini da aggiungere.
  • Dimensioni formazione : il numero di esempi di formazione da generare.
  • Tipo RNN : uno tra SimpleRNN , GRU o LSTM .
  • RNN Hidden Layer Size : dimensionalità dello spazio di output (deve essere un numero intero positivo).
  • Dimensione batch : numero di campioni per aggiornamento del gradiente.
  • Iterazioni di training : numero di volte per addestrare il modello invocando model.fit()
  • N. di esempi di test : numero di stringhe di esempio (ad esempio, 27+41 ) da generare.

Prova ad addestrare il modello con parametri diversi e vedi se riesci a migliorare la precisione delle previsioni per vari insiemi di cifre. Si noti inoltre come il tempo di adattamento del modello sia influenzato da diversi parametri.

Esplora il codice

L'app di esempio mostra alcuni dei parametri che puoi configurare per l'addestramento di un RNN. Dimostra inoltre l'utilizzo di un web workper per addestrare un modello dal thread principale. I Web Worker sono importanti nel Web ML perché consentono di eseguire attività di formazione costose dal punto di vista computazionale su un thread in background, evitando così problemi di prestazioni che potrebbero influire sull'utente nel thread principale. Il thread principale e quello di lavoro comunicano tra loro tramite eventi di messaggio.

Per ulteriori informazioni sui Web Worker, consulta API Web Worker e Utilizzo dei Web Worker .

Il modulo principale per l'app di esempio è index.js . Lo script index.js crea un web lavoratore che esegue il modulo worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js è in gran parte composto da una singola funzione, runAdditionRNNDemo , che gestisce l'invio del modulo, elabora i dati del modulo, passa i dati del modulo al lavoratore, attende che il lavoratore addestri il modello e restituisca i risultati, quindi visualizza i risultati sulla pagina .

Per inviare i dati del modulo al lavoratore, lo script richiama postMessage sul lavoratore:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

Il lavoratore ascolta questo messaggio e passa i dati del modulo alle funzioni che preparano i dati e avviano l'addestramento:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Durante il training, l'operatore può inviare due diversi tipi di messaggio, uno con isPredict impostato su true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

e l'altro con isPredict impostato su false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Quando il thread dell'interfaccia utente ( index.js ) gestisce gli eventi del messaggio, controlla il flag isPredict per determinare la forma dei dati restituiti dal lavoratore. Se isPredict è vero, i dati dovrebbero rappresentare una previsione e lo script aggiorna la pagina utilizzando tfjs-vis . Se isPredict è false, lo script esegue un blocco di codice che presuppone che i dati rappresentino esempi. Avvolge i dati in HTML e inserisce l'HTML nella pagina.

Qual è il prossimo

Questo tutorial ha fornito un esempio di utilizzo di un web workper evitare di bloccare il thread dell'interfaccia utente con un processo di formazione a lunga esecuzione. Per ulteriori informazioni sui vantaggi derivanti dall'esecuzione di calcoli costosi su un thread in background, vedere Utilizzare i web work per eseguire JavaScript dal thread principale del browser .

Per ulteriori informazioni sull'addestramento di un modello TensorFlow.js, consulta Modelli di addestramento .