Train a model using a web worker

In this tutorial, you'll explore an example web application that uses a web worker to train a Recurrent Neural Network (RNN) to do integer addition. The example app doesn't explicitly define the addition operator. Instead, it trains the RNN using example sums.

Of course, this is not the most efficient way to add two integers! But the tutorial demonstrates an important technique in web ML: how to perform long-running computations without blocking the main thread, which handles the UI logic.

The example application for this tutorial is available online, so you don't need to download any code or set up a development environment. To run the code locally instead, complete the optional steps in Run the example locally. Otherwise, skip to Explore the example.

The example code is available on GitHub.

(Optional) Run the example locally

Prerequisites

To run the example app locally, you need Node.js and Yarn (used in the steps below) installed in your development environment.

Install and run the example app

  1. Clone or download the tfjs-examples repository.
  2. Change into the addition-rnn-webworker directory:

    cd tfjs-examples/addition-rnn-webworker
    
  3. Install dependencies:

    yarn
    
  4. Start the development server:

    yarn run watch
    

Explore the example

Open the example app. (Or, if you're running the example locally, go to http://localhost:1234 in your browser.)

You should see a page titled TensorFlow.js: Addition RNN. Follow the instructions to try the app.

Using the web form, you can update some of the parameters used to train the model, including the following:

  • Digits: The maximum number of digits in the terms to be added.
  • Training Size: The number of training examples to generate.
  • RNN Type: One of SimpleRNN, GRU, or LSTM.
  • RNN Hidden Layer Size: Dimensionality of the output space (must be a positive integer).
  • Batch Size: Number of samples per gradient update.
  • Train Iterations: Number of times to train the model by invoking model.fit().
  • # of test examples: Number of example strings (for example, 27+41) to generate.
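
To make the Digits and # of test examples parameters concrete, the following is a minimal sketch of how example strings such as 27+41 could be generated. The generateExamples function and its injectable random argument are hypothetical names chosen for illustration. They are not taken from the example app, whose data generator may work differently.

```javascript
// Hypothetical helper, not the example app's actual data generator.
// Produces `count` addition problems whose terms have at most `digits` digits.
function generateExamples(digits, count, random = Math.random) {
  const maxTerm = 10 ** digits - 1;  // For digits = 2, terms fall in 0..99.
  const examples = [];
  for (let n = 0; n < count; n++) {
    const a = Math.floor(random() * (maxTerm + 1));
    const b = Math.floor(random() * (maxTerm + 1));
    // Store both the input string and the expected answer string.
    examples.push({question: `${a}+${b}`, answer: String(a + b)});
  }
  return examples;
}

// Usage: three two-digit problems (the values differ on every run).
console.log(generateExamples(2, 3));
```

Increasing Digits widens the range of terms, which is one reason larger values tend to need more training examples to reach the same accuracy.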

Try training the model with different parameters, and see if you can improve the accuracy of predictions for various sets of digits. Also notice how model fit time is affected by different parameters.

Explore the code

The example app demonstrates some of the parameters that you can configure for training an RNN. It also demonstrates the use of a web worker to train a model off the main thread. Web workers are important in web ML because they let you run computationally expensive training tasks on a background thread, so the main thread stays free to handle user input and update the page. The main and worker threads communicate with each other through message events.

To learn more about web workers, see Web Workers API and Using Web Workers.

The main module for the example app is index.js. The index.js script creates a web worker that runs the worker.js module:

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js is largely composed of a single function, runAdditionRNNDemo, that handles form submission, processes form data, passes the form data to the worker, waits for the worker to train the model and return results, and then displays the results on the page.

To send the form data to the worker, the script invokes postMessage on the worker:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});
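
Form fields hold strings, so the values need converting to numbers before they are posted. The following is a minimal sketch of that step, assuming every numeric parameter must be a positive integer. The buildTrainingMessage function and its validation rule are hypothetical; only the field names come from the postMessage call above.

```javascript
// Hypothetical helper: converts raw form strings into the message payload.
// The positive-integer rule is an assumption made for illustration.
function buildTrainingMessage(form) {
  const numericFields = [
    'digits', 'trainingSize', 'layers', 'hiddenSize',
    'trainIterations', 'batchSize', 'numTestExamples',
  ];
  const message = {rnnType: form.rnnType};
  for (const field of numericFields) {
    const value = Number(form[field]);
    if (!Number.isInteger(value) || value <= 0) {
      throw new RangeError(`${field} must be a positive integer`);
    }
    message[field] = value;
  }
  return message;
}

// Usage with sample values (in the browser, pass the result to worker.postMessage).
const sampleForm = {
  digits: '2', trainingSize: '5000', rnnType: 'SimpleRNN', layers: '1',
  hiddenSize: '128', trainIterations: '100', batchSize: '128',
  numTestExamples: '20',
};
console.log(buildTrainingMessage(sampleForm));
```

Validating on the main thread means a bad value is reported immediately instead of surfacing as a failure inside the worker.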

The worker listens for this message and passes the form data to functions that prepare the data and start the training:

self.addEventListener('message', async (e) => {
  const {
    digits, trainingSize, rnnType, layers, hiddenSize,
    trainIterations, batchSize, numTestExamples
  } = e.data;
  // Prepare the data, then start training.
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
});

During training, the worker can send two different message types. The first has isPredict set to true:

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

The second has isPredict set to false:

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

When the UI thread (index.js) handles message events, it checks the isPredict flag to determine the shape of data returned from the worker. If isPredict is true, the data should represent a prediction, and the script updates the page using tfjs-vis. If isPredict is false, the script runs a block of code that assumes that the data represents examples. It wraps the data in HTML and inserts the HTML into the page.
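
The branching described above can be sketched as a small dispatcher. This is an illustration, not the code in index.js: the handleWorkerMessage function and the ui object with its showProgress and showExamples callbacks are hypothetical, and the sketch assumes that isCorrect and examples are parallel arrays.

```javascript
// Hypothetical dispatcher. `ui` is an assumed object with two callbacks; the
// example app updates tfjs-vis charts and the page directly instead.
function handleWorkerMessage(data, ui) {
  if (data.isPredict) {
    // Training metrics for one iteration.
    const {i, iterations, modelFitTime, lossValues, accuracyValues} = data;
    ui.showProgress({i, iterations, modelFitTime, lossValues, accuracyValues});
  } else {
    // Assumption: isCorrect[n] says whether examples[n] was predicted correctly.
    const {isCorrect, examples} = data;
    ui.showExamples(
        examples.map((text, n) => ({text, correct: isCorrect[n]})));
  }
}

// In the browser, this would be wired up as:
// worker.addEventListener('message', (e) => handleWorkerMessage(e.data, ui));
```

Keeping the dispatcher separate from the event listener makes the branching testable without a browser or a real worker.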

What's next

This tutorial has provided an example of using a web worker to avoid blocking the UI thread with a long-running training process. To learn more about the benefits of doing expensive computation on a background thread, see Use web workers to run JavaScript off the browser's main thread.

To learn more about training a TensorFlow.js model, see Training models.