Use um modelo pré-treinado

Neste tutorial, você explorará um exemplo de aplicativo Web que demonstra a aprendizagem por transferência usando a API TensorFlow.js Layers. O exemplo carrega um modelo pré-treinado e, em seguida, treina novamente o modelo no navegador.

O modelo foi pré-treinado em Python nos dígitos 0 a 4 do conjunto de dados de classificação de dígitos MNIST . A reciclagem (ou aprendizagem por transferência) no navegador usa os dígitos 5 a 9. O exemplo mostra que as primeiras camadas de um modelo pré-treinado podem ser usadas para extrair recursos de novos dados durante a aprendizagem por transferência, permitindo assim um treinamento mais rápido nos novos dados.

O aplicativo de exemplo deste tutorial está disponível online , portanto você não precisa baixar nenhum código ou configurar um ambiente de desenvolvimento. Se quiser executar o código localmente, conclua as etapas opcionais em Executar o exemplo localmente . Se não quiser configurar um ambiente de desenvolvimento, você pode pular para Explorar o exemplo .

O código de exemplo está disponível no GitHub .

(Opcional) Execute o exemplo localmente

Pré-requisitos

Para executar o aplicativo de exemplo localmente, você precisa do seguinte instalado em seu ambiente de desenvolvimento:

Instale e execute o aplicativo de exemplo

  1. Clone ou baixe o repositório tfjs-examples .
  2. Mude para o diretório mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Instale dependências:

    yarn
    
  4. Inicie o servidor de desenvolvimento:

    yarn run watch
    

Explore o exemplo

Abra o aplicativo de exemplo . (Ou, se você estiver executando o exemplo localmente, acesse http://localhost:1234 no seu navegador.)

Você deverá ver uma página intitulada MNIST CNN Transfer Learning . Siga as instruções para experimentar o aplicativo.

Aqui estão algumas coisas para tentar:

  • Experimente os diferentes modos de treinamento e compare a perda e a precisão.
  • Selecione diferentes exemplos de bitmap e inspecione as probabilidades de classificação. Observe que os números em cada exemplo de bitmap são valores inteiros em escala de cinza que representam pixels de uma imagem.
  • Edite os valores inteiros do bitmap e veja como as alterações afetam as probabilidades de classificação.

Explorar o código

O aplicativo Web de exemplo carrega um modelo que foi pré-treinado em um subconjunto do conjunto de dados MNIST. O pré-treinamento é definido em um programa Python: mnist_transfer_cnn.py . O programa Python está fora do escopo deste tutorial, mas vale a pena dar uma olhada se você quiser ver um exemplo de conversão de modelo .

O arquivo index.js contém a maior parte do código de treinamento da demonstração. Quando index.js é executado no navegador, uma função de configuração, setupMnistTransferCNN , instancia e inicializa MnistTransferCNNPredictor , que encapsula as rotinas de retreinamento e previsão.

O método de inicialização, MnistTransferCNNPredictor.init , carrega um modelo, carrega dados de retreinamento e cria dados de teste. Aqui está a linha que carrega o modelo:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Se você observar a definição de loader.loadHostedPretrainedModel , verá que ela retorna o resultado de uma chamada para tf.loadLayersModel . Esta é a API TensorFlow.js para carregar um modelo composto de objetos Layer.

A lógica de retreinamento é definida em MnistTransferCNNPredictor.retrainModel . Se o usuário tiver selecionado Congelar camadas de feição como modo de treinamento, as primeiras 7 camadas do modelo base serão congeladas e apenas as 5 camadas finais serão treinadas em novos dados. Se o usuário tiver selecionado Reinicializar pesos , todos os pesos serão redefinidos e o aplicativo treinará efetivamente o modelo do zero.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

O modelo é então compilado e treinado nos dados de teste usando model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Para saber mais sobre os parâmetros model.fit() , consulte a documentação da API .

Depois de ser treinado no novo conjunto de dados (dígitos 5 a 9), o modelo pode ser usado para fazer previsões. O método MnistTransferCNNPredictor.predict faz isso usando model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Observe o uso de tf.tidy , que ajuda a evitar vazamentos de memória.

Saber mais

Este tutorial explorou um aplicativo de exemplo que realiza aprendizagem por transferência no navegador usando TensorFlow.js. Confira os recursos abaixo para saber mais sobre modelos pré-treinados e transferência de aprendizagem.

TensorFlow.js

Núcleo do TensorFlow