In this tutorial you'll explore an example web application that demonstrates transfer learning using the TensorFlow.js Layers API. The example loads a pre-trained model and then retrains the model in the browser.
The model has been pre-trained in Python on digits 0-4 of the MNIST digits classification dataset. The retraining (or transfer learning) in the browser uses digits 5-9. The example shows that the first several layers of a pre-trained model can be used to extract features from new data during transfer learning, thus enabling faster training on the new data.
The example application for this tutorial is available online, so you don't need to download any code or set up a development environment. If you'd like to run the code locally, complete the optional steps in Run the example locally. If you don't want to set up a development environment, you can skip to Explore the example.
The example code is available on GitHub.
(Optional) Run the example locally
Prerequisites
To run the example app locally, you need Node.js and Yarn installed in your development environment.
Install and run the example app
- Clone or download the tfjs-examples repository.
- Change into the mnist-transfer-cnn directory:
  cd tfjs-examples/mnist-transfer-cnn
- Install dependencies:
  yarn
- Start the development server:
  yarn run watch
Explore the example
Open the example app. (Or, if you're running the example locally, go to http://localhost:1234 in your browser.)
You should see a page titled MNIST CNN Transfer Learning. Follow the instructions to try the app.
Here are a few things to try:
- Experiment with the different training modes and compare loss and accuracy.
- Select different bitmap examples and inspect the classification probabilities. Note that the numbers in each bitmap example are grayscale integer values representing pixels from an image.
- Edit the bitmap integer values and see how the changes affect classification probabilities.
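To make the bitmap format concrete, here is a minimal sketch of how a block of grayscale integers could be turned into model input. The helper name parseBitmapText is hypothetical and is not the app's own util.textToImageArray; the 0-255 value range and the scaling to 0-1 are assumptions made for illustration.

```javascript
// Hypothetical helper: parses whitespace-separated grayscale integers
// into a flat array of floats. It only illustrates the idea and is not
// the parsing code used by the example app.
function parseBitmapText(text, imageSize = 28) {
  const values = text.trim().split(/\s+/).map(Number);
  const expected = imageSize * imageSize;
  if (values.length !== expected) {
    throw new Error(`Expected ${expected} pixel values, got ${values.length}`);
  }
  for (const v of values) {
    // Assumption: each pixel is an integer in the range 0-255.
    if (!Number.isInteger(v) || v < 0 || v > 255) {
      throw new Error(`Invalid pixel value: ${v}`);
    }
  }
  // Assumption: pixels are scaled from 0-255 to 0-1 before prediction.
  return values.map((v) => v / 255);
}

// A tiny 2x2 "image" keeps the example readable; MNIST images are 28x28.
const pixels = parseBitmapText('0 255\n51 0', 2);
console.log(pixels.length);
```

Under these assumptions, editing a single integer changes exactly one input pixel, which is why small edits can shift the classification probabilities.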
Explore the code
The example web app loads a model that has been pre-trained on a subset of the MNIST dataset. The pre-training is defined in a Python program: mnist_transfer_cnn.py. The Python program is out of scope for this tutorial, but it's worth looking at if you'd like to see an example of model conversion.
The index.js file contains most of the training code for the demo. When index.js runs in the browser, a setup function, setupMnistTransferCNN, instantiates and initializes MnistTransferCNNPredictor, which encapsulates the retraining and prediction routines.
The initialization method, MnistTransferCNNPredictor.init, loads a model, loads retraining data, and creates test data. Here's the line that loads the model:
this.model = await loader.loadHostedPretrainedModel(urls.model);
If you look at the definition of loader.loadHostedPretrainedModel, you'll see that it returns the result of a call to tf.loadLayersModel. This is the TensorFlow.js API for loading a model composed of Layer objects.
The retraining logic is defined in MnistTransferCNNPredictor.retrainModel. If the user has selected Freeze feature layers as the training mode, the first 7 layers of the base model are frozen, and only the final 5 layers are trained on new data. If the user has selected Reinitialize weights, all the weights are reset, and the app effectively trains the model from scratch.
if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}
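The freezing step can be tried in isolation. The following is a minimal sketch, not the app's code: freezeFirstLayers is a hypothetical helper, and mockModel is a plain-object stand-in for a 12-layer model (7 feature layers plus 5 final layers), so the snippet runs without TensorFlow.js. It relies only on the layers array and the trainable property used in the snippet above.

```javascript
// Hypothetical helper: marks the first `count` layers as non-trainable
// and returns how many layers remain trainable.
function freezeFirstLayers(model, count) {
  const n = Math.min(count, model.layers.length);
  for (let i = 0; i < n; ++i) {
    model.layers[i].trainable = false;
  }
  return model.layers.filter((layer) => layer.trainable).length;
}

// Stand-in for a loaded model: 12 layers, all trainable at first.
const mockModel = {
  layers: Array.from({length: 12}, (_, i) => ({
    name: `layer_${i}`,
    trainable: true,
  })),
};

const trainableCount = freezeFirstLayers(mockModel, 7);
console.log(trainableCount); // 5
```

Only the remaining trainable layers have their weights updated during retraining, which is what makes this mode faster than training every layer.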
The model is then compiled, and then it's trained on the training data, with the test data used for validation, using model.fit():
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});
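Each entry in the callbacks array is an object whose onEpochEnd function receives the epoch index and a logs object. As a rough illustration, the sketch below records the two validation metrics named in the call above. The makeHistoryRecorder helper is hypothetical, and the callback is invoked by hand with made-up log values, so the snippet runs without TensorFlow.js.

```javascript
// Hypothetical helper: builds a callbacks object that records the
// validation metrics reported at the end of each epoch.
function makeHistoryRecorder() {
  const history = [];
  const callbacks = {
    onEpochEnd: (epoch, logs) => {
      history.push({epoch, valLoss: logs.val_loss, valAcc: logs.val_acc});
    },
  };
  return {history, callbacks};
}

const recorder = makeHistoryRecorder();

// Simulated epochs with made-up numbers; during real training,
// model.fit() would call onEpochEnd itself.
recorder.callbacks.onEpochEnd(0, {val_loss: 0.5, val_acc: 0.8});
recorder.callbacks.onEpochEnd(1, {val_loss: 0.25, val_acc: 0.9});

console.log(recorder.history.length); // 2
```

In an actual run, recorder.callbacks would be added to the callbacks array passed to model.fit(), alongside the progress bar and chart callbacks.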
To learn more about the model.fit() parameters, see the API documentation.
After being trained on the new dataset (digits 5-9), the model can be used to make predictions. The MnistTransferCNNPredictor.predict method does this using model.predict():
// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);
      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}
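The + 5 in the code above is there because the retrained model distinguishes the digits 5-9, so the index of the winning class (0-4) has to be shifted to get the digit. Here is a plain JavaScript sketch of that mapping; probabilitiesToDigit and DIGIT_OFFSET are hypothetical names, and the probabilities are made up.

```javascript
// Offset between a class index and the digit it represents, because
// the retrained classes are the digits 5-9.
const DIGIT_OFFSET = 5;

// Hypothetical helper: returns the predicted digit for an array of
// class probabilities by taking the index of the largest value.
function probabilitiesToDigit(probabilities) {
  let best = 0;
  for (let i = 1; i < probabilities.length; ++i) {
    if (probabilities[i] > probabilities[best]) {
      best = i;
    }
  }
  return best + DIGIT_OFFSET;
}

// Made-up probabilities: class index 2 wins, which corresponds to digit 7.
console.log(probabilitiesToDigit([0.05, 0.1, 0.7, 0.1, 0.05])); // 7
```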
Note the use of tf.tidy, which disposes of the intermediate tensors created inside the callback and so helps prevent memory leaks.
Learn more
This tutorial has explored an example app that performs transfer learning in the browser using TensorFlow.js. Check out the resources below to learn more about pre-trained models and transfer learning.
TensorFlow.js
- Importing a Keras model into TensorFlow.js
- Import a TensorFlow model into TensorFlow.js
- Pre-made models for TensorFlow.js
TensorFlow Core