Use a pre-trained model

In this tutorial you'll explore an example web application that demonstrates transfer learning using the TensorFlow.js Layers API. The example loads a pre-trained model and then retrains the model in the browser.

The model has been pre-trained in Python on digits 0-4 of the MNIST digits classification dataset. The retraining (or transfer learning) in the browser uses digits 5-9. The example shows that the first several layers of a pre-trained model can be used to extract features from new data during transfer learning, thus enabling faster training on the new data.

The example application for this tutorial is available online, so you don't need to download any code or set up a development environment. If you'd like to run the code locally, complete the optional steps in Run the example locally. If you don't want to set up a development environment, you can skip to Explore the example.

The example code is available on GitHub.

(Optional) Run the example locally

Prerequisites

To run the example app locally, you need Node.js and Yarn installed in your development environment.

Install and run the example app

  1. Clone or download the tfjs-examples repository.
  2. Change into the mnist-transfer-cnn directory:

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Install dependencies:

    yarn
    
  4. Start the development server:

    yarn run watch
    

Explore the example

Open the example app. (Or, if you're running the example locally, go to http://localhost:1234 in your browser.)

You should see a page titled MNIST CNN Transfer Learning. Follow the instructions to try the app.

Here are a few things to try:

  • Experiment with the different training modes and compare loss and accuracy.
  • Select different bitmap examples and inspect the classification probabilities. Note that the numbers in each bitmap example are grayscale integer values representing pixels from an image.
  • Edit the bitmap integer values and see how the changes affect classification probabilities.
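
To make the bitmap format concrete, here is a minimal sketch of how a block of whitespace-separated grayscale integers could be parsed into a flat array of pixel values. The function name parseBitmapText and the scaling by 255 are assumptions made for illustration; the app's own helper may differ.

```javascript
// Hypothetical helper (not part of the example app): parse a text bitmap of
// whitespace-separated grayscale integers into a flat array scaled to [0, 1].
function parseBitmapText(imageText, imageSize) {
  const values = imageText
    .trim()
    .split(/\s+/)
    .map((token) => Number.parseInt(token, 10));

  if (values.length !== imageSize * imageSize) {
    throw new Error(
      `Expected ${imageSize * imageSize} pixel values, got ${values.length}.`);
  }
  if (values.some((v) => Number.isNaN(v) || v < 0 || v > 255)) {
    throw new Error('Pixel values must be integers between 0 and 255.');
  }
  // Assumption: pixels are scaled by 1/255 before being fed to the model.
  return values.map((v) => v / 255);
}

// Usage with a tiny 2x2 bitmap (MNIST images are 28x28).
const pixels = parseBitmapText('0 255\n51 102', 2);
console.log(pixels);
```

Editing a single integer in the bitmap changes exactly one entry of this array, which is why small edits can shift the classification probabilities.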

Explore the code

The example web app loads a model that has been pre-trained on a subset of the MNIST dataset. The pre-training is defined in a Python program: mnist_transfer_cnn.py. The Python program is out of scope for this tutorial, but it's worth looking at if you'd like to see an example of model conversion.

The index.js file contains most of the training code for the demo. When index.js runs in the browser, a setup function, setupMnistTransferCNN, instantiates and initializes MnistTransferCNNPredictor, which encapsulates the retraining and prediction routines.

The initialization method, MnistTransferCNNPredictor.init, loads a model, loads retraining data, and creates test data. Here's the line that loads the model:

this.model = await loader.loadHostedPretrainedModel(urls.model);

If you look at the definition of loader.loadHostedPretrainedModel, you'll see that it returns the result of a call to tf.loadLayersModel. This is the TensorFlow.js API for loading a model composed of Layer objects.
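
As a rough illustration of what such a loader can look like, the sketch below wraps tf.loadLayersModel with basic error handling. The function name loadPretrainedModel, the choice to return null on failure, and passing tf as a parameter are assumptions for this sketch, not the example's actual loader code.

```javascript
// Sketch of a loader wrapper. `tf` is passed in so the snippet does not assume
// how TensorFlow.js was imported (script tag, bundler, or Node package).
async function loadPretrainedModel(tf, url) {
  try {
    // tf.loadLayersModel resolves to a tf.LayersModel built from the
    // model.json file at `url` and the weight files it references.
    return await tf.loadLayersModel(url);
  } catch (err) {
    console.error(`Loading the model from ${url} failed:`, err);
    return null;
  }
}

// Usage inside an async function, with a placeholder URL:
//   const model = await loadPretrainedModel(tf, 'https://example.com/model.json');
```

Returning null keeps the failure explicit for the caller; an app could just as reasonably rethrow the error and show it in the UI.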

The retraining logic is defined in MnistTransferCNNPredictor.retrainModel. If the user has selected Freeze feature layers as the training mode, the first 7 layers of the base model are frozen, and only the final 5 layers are trained on new data. If the user has selected Reinitialize weights, all the weights are reset, and the app effectively trains the model from scratch.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}
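
To see the effect of the freezing loop in isolation, the following sketch applies the same idea to a stand-in object and reports which layers remain trainable. The helper name freezeFirstLayers and the made-up layer names are hypothetical; the only properties it relies on are layers, name, and trainable, which a tf.LayersModel also exposes.

```javascript
// Hypothetical helper: freeze the first `count` layers and return the names
// of the layers that remain trainable.
function freezeFirstLayers(model, count) {
  const limit = Math.min(count, model.layers.length);
  for (let i = 0; i < limit; ++i) {
    model.layers[i].trainable = false;
  }
  return model.layers
    .filter((layer) => layer.trainable)
    .map((layer) => layer.name);
}

// Stand-in for the 12-layer model described above (layer names are made up).
const standInModel = {
  layers: Array.from({length: 12}, (_, i) => ({
    name: `layer_${i}`,
    trainable: true,
  })),
};

const stillTrainable = freezeFirstLayers(standInModel, 7);
console.log(stillTrainable.length); // 5
```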

The model is then compiled and trained on the new training data (digits 5-9) using model.fit(), with the held-out test data passed as validationData:

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

To learn more about the model.fit() parameters, see the API documentation.
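
The compile step that precedes model.fit() is not shown in the excerpt above. A minimal sketch of what it might look like follows; the loss, optimizer, and metric are assumptions that suit one-hot digit labels, not values quoted from index.js, and compileForRetraining is a hypothetical name.

```javascript
// Sketch of the compile step. The configuration values are assumptions:
// categorical cross-entropy for one-hot labels, the Adam optimizer, and an
// accuracy metric.
function compileForRetraining(model) {
  model.compile({
    loss: 'categoricalCrossentropy',
    optimizer: 'adam',
    metrics: ['acc'],
  });
  return model;
}
```

An accuracy metric named 'acc' would be consistent with the 'val_acc' series that the fit callbacks plot for the validation data.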

After being trained on the new dataset (digits 5-9), the model can be used to make predictions. The MnistTransferCNNPredictor.predict method does this using model.predict():

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

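Note the use of tf.tidy, which disposes of the intermediate tensors created inside the callback (here image, predictOut, and winner) once it returns, and so helps prevent memory leaks.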
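
The + 5 in the call to ui.setPredictResults is there because the retrained model's output classes are indexed from 0, while the digits they stand for start at 5. The plain-JavaScript sketch below mirrors that argMax-plus-offset step on an ordinary array; predictedDigit is a hypothetical name used only for illustration.

```javascript
// Map class probabilities from the retrained model back to a digit label.
// Output index 0..4 stands for digits 5..9, hence the default offset of 5.
function predictedDigit(probabilities, labelOffset = 5) {
  let best = 0;
  for (let i = 1; i < probabilities.length; ++i) {
    if (probabilities[i] > probabilities[best]) {
      best = i;
    }
  }
  return best + labelOffset;
}

console.log(predictedDigit([0.05, 0.1, 0.7, 0.1, 0.05])); // 7
```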

Learn more

This tutorial has explored an example app that performs transfer learning in the browser using TensorFlow.js. Check out the resources below to learn more about pre-trained models and transfer learning.

TensorFlow.js

TensorFlow Core