Используйте предварительно обученную модель

В этом руководстве вы изучите пример веб-приложения, демонстрирующего трансферное обучение с использованием API слоев TensorFlow.js. В примере загружается предварительно обученная модель, а затем она повторно обучается в браузере.

Модель была предварительно обучена на Python на цифрах 0–4 из набора данных классификации цифр MNIST . Переобучение (или трансферное обучение) в браузере использует цифры 5-9. Пример показывает, что первые несколько слоев предварительно обученной модели можно использовать для извлечения признаков из новых данных во время трансферного обучения, что позволяет ускорить обучение на новых данных.

Пример приложения для этого руководства доступен в Интернете , поэтому вам не нужно загружать какой-либо код или настраивать среду разработки. Если вы хотите запустить код локально, выполните дополнительные действия, описанные в разделе «Выполнение примера локально» . Если вы не хотите настраивать среду разработки, вы можете перейти к разделу «Изучите пример» .

Код примера доступен на GitHub .

(Необязательно) Запустите пример локально.

Предварительные условия

Чтобы запустить пример приложения локально, в вашей среде разработки необходимо установить следующее:

Установите и запустите пример приложения

  1. Клонируйте или загрузите репозиторий tfjs-examples .
  2. Перейдите в каталог mnist-transfer-cnn :

    cd tfjs-examples/mnist-transfer-cnn
    
  3. Установите зависимости:

    yarn
    
  4. Запустите сервер разработки:

    yarn run watch
    

Изучите пример

Откройте пример приложения . (Или, если вы запускаете пример локально, перейдите по адресу http://localhost:1234 в своем браузере.)

Вы должны увидеть страницу под названием MNIST CNN Transfer Learning . Следуйте инструкциям, чтобы попробовать приложение.

Вот несколько вещей, которые стоит попробовать:

  • Поэкспериментируйте с различными режимами обучения и сравните потери и точность.
  • Выберите различные примеры растровых изображений и проверьте вероятности классификации. Обратите внимание, что числа в каждом примере растрового изображения представляют собой целые значения в оттенках серого, представляющие пиксели изображения.
  • Отредактируйте целочисленные значения растрового изображения и посмотрите, как изменения влияют на вероятность классификации.

Изучите код

Пример веб-приложения загружает модель, предварительно обученную на подмножестве набора данных MNIST. Предварительное обучение определяется в программе Python: mnist_transfer_cnn.py . Программа Python выходит за рамки этого руководства, но ее стоит посмотреть, если вы хотите увидеть пример преобразования модели .

Файл index.js содержит большую часть обучающего кода для демонстрации. Когда index.js запускается в браузере, функция настройки setupMnistTransferCNN создает экземпляр и инициализирует MnistTransferCNNPredictor , который инкапсулирует процедуры переобучения и прогнозирования.

Метод инициализации MnistTransferCNNPredictor.init загружает модель, загружает данные переобучения и создает тестовые данные. Вот строка , которая загружает модель:

this.model = await loader.loadHostedPretrainedModel(urls.model);

Если вы посмотрите на определение loader.loadHostedPretrainedModel , вы увидите, что оно возвращает результат вызова tf.loadLayersModel . Это API TensorFlow.js для загрузки модели, состоящей из объектов Layer.

Логика переобучения определена в MnistTransferCNNPredictor.retrainModel . Если пользователь выбрал «Заморозить векторные слои» в качестве режима обучения, первые 7 слоев базовой модели замораживаются, и только последние 5 слоев обучаются на новых данных. Если пользователь выбрал «Повторно инициализировать веса» , все веса сбрасываются, и приложение эффективно обучает модель с нуля.

if (trainingMode === 'freeze-feature-layers') {
  console.log('Freezing feature layers of the model.');
  for (let i = 0; i < 7; ++i) {
    this.model.layers[i].trainable = false;
  }
} else if (trainingMode === 'reinitialize-weights') {
  // Make a model with the same topology as before, but with re-initialized
  // weight values.
  const returnString = false;
  this.model = await tf.models.modelFromJSON({
    modelTopology: this.model.toJSON(null, returnString)
  });
}

Затем модель компилируется , а затем обучается на тестовых данных с помощью model.fit() :

await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
  batchSize: batchSize,
  epochs: epochs,
  validationData: [this.gte5TestData.x, this.gte5TestData.y],
  callbacks: [
    ui.getProgressBarCallbackConfig(epochs),
    tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
      zoomToFit: true,
      zoomToFitAccuracy: true,
      height: 200,
      callbacks: ['onEpochEnd'],
    }),
  ]
});

Подробнее о параметрах model.fit() см. в документации API .

После обучения на новом наборе данных (цифры 5–9) модель можно использовать для прогнозирования. Метод MnistTransferCNNPredictor.predict делает это с помощью model.predict() :

// Perform prediction on the input image using the loaded model.
predict(imageText) {
  tf.tidy(() => {
    try {
      const image = util.textToImageArray(imageText, this.imageSize);
      const predictOut = this.model.predict(image);
      const winner = predictOut.argMax(1);

      ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
    } catch (e) {
      ui.setPredictError(e.message);
    }
  });
}

Обратите внимание на использование tf.tidy , который помогает предотвратить утечки памяти.

Узнать больше

В этом руководстве был рассмотрен пример приложения, которое выполняет перенос обучения в браузере с использованием TensorFlow.js. Ознакомьтесь с ресурсами ниже, чтобы узнать больше о предварительно обученных моделях и трансферном обучении.

TensorFlow.js

Ядро TensorFlow