Используйте предварительно обученную модель

В этом руководстве вы изучите пример веб-приложения, демонстрирующий передачу обучения с использованием API слоев TensorFlow.js. В примере загружается предварительно обученная модель, а затем повторно обучается модель в браузере.

Модель была предварительно обучена в Python на цифрах 0-4 набора данных классификации цифр MNIST . При переобучении (или переносе обучения) в браузере используются цифры 5-9. Пример показывает, что первые несколько слоев предварительно обученной модели можно использовать для извлечения признаков из новых данных во время обучения с переносом, что позволяет ускорить обучение на новых данных.

Пример приложения для этого руководства доступен в Интернете , поэтому вам не нужно загружать какой-либо код или настраивать среду разработки. Если вы хотите запустить код локально, выполните необязательные шаги в разделе Запуск примера локально . Если вы не хотите настраивать среду разработки, вы можете перейти к разделу «Изучение примера» .

Код примера доступен на GitHub .

(Необязательно) Запустите пример локально

Предпосылки

Чтобы запустить пример приложения локально, в вашей среде разработки должны быть установлены следующие компоненты:

Установите и запустите пример приложения

  1. Клонируйте или скачайте репозиторий tfjs-examples .
  2. Перейдите в каталог mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Установите зависимости:

    yarn
    
  4. Запустите сервер разработки:

    yarn run watch
    

Изучите пример

Откройте пример приложения . (Или, если вы запускаете пример локально, перейдите по адресу http://localhost:1234 в браузере.)

Вы должны увидеть страницу под названием MNIST CNN Transfer Learning . Следуйте инструкциям, чтобы попробовать приложение.

Вот несколько вещей, которые можно попробовать:

  • Поэкспериментируйте с различными режимами обучения и сравните потери и точность.
  • Выберите различные примеры растровых изображений и проверьте вероятности классификации. Обратите внимание, что числа в каждом примере растрового изображения представляют собой целочисленные значения в оттенках серого, представляющие пиксели изображения.
  • Отредактируйте целочисленные значения растрового изображения и посмотрите, как изменения повлияют на вероятности классификации.

Исследуйте код

Пример веб-приложения загружает модель, предварительно обученную на подмножестве набора данных MNIST. Предварительное обучение определено в программе Python: mnist_transfer_cnn.py . Программа Python выходит за рамки этого руководства, но ее стоит посмотреть, если вы хотите увидеть пример преобразования модели .

Файл index.js содержит большую часть обучающего кода для демонстрации. Когда index.js запускается в браузере, функция настройки setupMnistTransferCNN создает и инициализирует MnistTransferCNNPredictor , который инкапсулирует процедуры переобучения и прогнозирования.

Метод инициализации MnistTransferCNNPredictor.init загружает модель, загружает данные переобучения и создает тестовые данные. Вот строка , загружающая модель:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Если вы посмотрите на определение loader.loadHostedPretrainedModel , вы увидите, что оно возвращает результат вызова tf.loadLayersModel . Это API TensorFlow.js для загрузки модели, состоящей из объектов Layer.

Логика переобучения определена в MnistTransferCNNPredictor.retrainModel . Если пользователь выбрал Заморозить векторные слои в качестве режима обучения, первые 7 слоев базовой модели замораживаются, и только последние 5 слоев обучаются на новых данных. Если пользователь выбрал Повторная инициализация весов , все веса сбрасываются, и приложение эффективно обучает модель с нуля.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Затем модель компилируется , а затем обучается на тестовых данных с помощью model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Чтобы узнать больше о параметрах model.fit() , см. документацию по API .

После обучения на новом наборе данных (цифры 5-9) модель можно использовать для прогнозирования. Метод MnistTransferCNNPredictor.predict делает это с помощью model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Обратите внимание на использование tf.tidy , который помогает предотвратить утечку памяти.

Узнать больше

В этом руководстве был рассмотрен пример приложения, которое выполняет перенос обучения в браузере с помощью TensorFlow.js. Ознакомьтесь с приведенными ниже ресурсами, чтобы узнать больше о предварительно обученных моделях и переносе обучения.

TensorFlow.js

Ядро TensorFlow