В этом руководстве вы изучите пример веб-приложения, демонстрирующий передачу обучения с использованием API слоев TensorFlow.js. В примере загружается предварительно обученная модель, а затем повторно обучается модель в браузере.
Модель была предварительно обучена в Python на цифрах 0-4 набора данных классификации цифр MNIST . При переобучении (или переносе обучения) в браузере используются цифры 5-9. Пример показывает, что первые несколько слоев предварительно обученной модели можно использовать для извлечения признаков из новых данных во время обучения с переносом, что позволяет ускорить обучение на новых данных.
Пример приложения для этого руководства доступен в Интернете , поэтому вам не нужно загружать какой-либо код или настраивать среду разработки. Если вы хотите запустить код локально, выполните необязательные шаги в разделе Запуск примера локально . Если вы не хотите настраивать среду разработки, вы можете перейти к разделу «Изучение примера» .
Код примера доступен на GitHub .
(Необязательно) Запустите пример локально
Предпосылки
Чтобы запустить пример приложения локально, в вашей среде разработки должны быть установлены следующие компоненты:
- Node.js ( скачать )
- Пряжа ( установить )
Установите и запустите пример приложения
- Клонируйте или скачайте репозиторий
tfjs-examples
. Перейдите в каталог
mnist-transfer-cnn
:cd tfjs-examples/mnist-transfer-cnn
Установите зависимости:
yarn
Запустите сервер разработки:
yarn run watch
Изучите пример
Откройте пример приложения . (Или, если вы запускаете пример локально, перейдите по адресу http://localhost:1234
в браузере.)
Вы должны увидеть страницу под названием MNIST CNN Transfer Learning . Следуйте инструкциям, чтобы попробовать приложение.
Вот несколько вещей, которые можно попробовать:
- Поэкспериментируйте с различными режимами обучения и сравните потери и точность.
- Выберите различные примеры растровых изображений и проверьте вероятности классификации. Обратите внимание, что числа в каждом примере растрового изображения представляют собой целочисленные значения в оттенках серого, представляющие пиксели изображения.
- Отредактируйте целочисленные значения растрового изображения и посмотрите, как изменения повлияют на вероятности классификации.
Исследуйте код
Пример веб-приложения загружает модель, предварительно обученную на подмножестве набора данных MNIST. Предварительное обучение определено в программе Python: mnist_transfer_cnn.py
. Программа Python выходит за рамки этого руководства, но ее стоит посмотреть, если вы хотите увидеть пример преобразования модели .
Файл index.js
содержит большую часть обучающего кода для демонстрации. Когда index.js
запускается в браузере, функция настройки setupMnistTransferCNN
создает и инициализирует MnistTransferCNNPredictor
, который инкапсулирует процедуры переобучения и прогнозирования.
Метод инициализации MnistTransferCNNPredictor.init
загружает модель, загружает данные переобучения и создает тестовые данные. Вот строка , загружающая модель:
this.model = await loader.loadHostedPretrainedModel(urls.model);
Если вы посмотрите на определение loader.loadHostedPretrainedModel
, вы увидите, что оно возвращает результат вызова tf.loadLayersModel
. Это API TensorFlow.js для загрузки модели, состоящей из объектов Layer.
Логика переобучения определена в MnistTransferCNNPredictor.retrainModel
. Если пользователь выбрал Заморозить векторные слои в качестве режима обучения, первые 7 слоев базовой модели замораживаются, и только последние 5 слоев обучаются на новых данных. Если пользователь выбрал Повторная инициализация весов , все веса сбрасываются, и приложение эффективно обучает модель с нуля.
if (trainingMode === 'freeze-feature-layers') {
console.log('Freezing feature layers of the model.');
for (let i = 0; i < 7; ++i) {
this.model.layers[i].trainable = false;
}
} else if (trainingMode === 'reinitialize-weights') {
// Make a model with the same topology as before, but with re-initialized
// weight values.
const returnString = false;
this.model = await tf.models.modelFromJSON({
modelTopology: this.model.toJSON(null, returnString)
});
}
Затем модель компилируется , а затем обучается на тестовых данных с помощью model.fit()
:
await this.model.fit(this.gte5TrainData.x, this.gte5TrainData.y, {
batchSize: batchSize,
epochs: epochs,
validationData: [this.gte5TestData.x, this.gte5TestData.y],
callbacks: [
ui.getProgressBarCallbackConfig(epochs),
tfVis.show.fitCallbacks(surfaceInfo, ['val_loss', 'val_acc'], {
zoomToFit: true,
zoomToFitAccuracy: true,
height: 200,
callbacks: ['onEpochEnd'],
}),
]
});
Чтобы узнать больше о параметрах model.fit()
, см. документацию по API .
После обучения на новом наборе данных (цифры 5-9) модель можно использовать для прогнозирования. Метод MnistTransferCNNPredictor.predict
делает это с помощью model.predict()
:
// Perform prediction on the input image using the loaded model.
predict(imageText) {
tf.tidy(() => {
try {
const image = util.textToImageArray(imageText, this.imageSize);
const predictOut = this.model.predict(image);
const winner = predictOut.argMax(1);
ui.setPredictResults(predictOut.dataSync(), winner.dataSync()[0] + 5);
} catch (e) {
ui.setPredictError(e.message);
}
});
}
Обратите внимание на использование tf.tidy
, который помогает предотвратить утечку памяти.
Узнать больше
В этом руководстве был рассмотрен пример приложения, которое выполняет перенос обучения в браузере с помощью TensorFlow.js. Ознакомьтесь с приведенными ниже ресурсами, чтобы узнать больше о предварительно обученных моделях и переносе обучения.
TensorFlow.js
- Импорт модели Keras в TensorFlow.js
- Импорт модели TensorFlow в TensorFlow.js
- Готовые модели для TensorFlow.js
Ядро TensorFlow