Тренировочный цикл

При обучении модели машинного обучения обычно имеется цикл, в котором данные обучения принимаются (или генерируются), пакеты пропускаются через модель, получаются градиенты и модель обновляется с помощью оптимизатора. Хотя вы можете написать собственный цикл обучения для каждого обучающего приложения, Swift для TensorFlow предоставляет экспериментальную абстракцию цикла обучения, которая может упростить этот процесс.

Модуль TrainingLoop в репозитории моделей содержит текущую версию этого экспериментального обобщенного цикла обучения. Он структурирован таким образом, чтобы интегрироваться с оболочками наборов данных, соответствующими API Epochs, для упрощения приема данных и автоматизировать взаимодействие моделей, наборов данных и оптимизаторов с серверными модулями ускорителей для достижения оптимальной производительности. Тяжелая настройка процесса обучения может быть достигнута за счет использования обратных вызовов.

Большинство примеров на основе изображений в репозитории моделей были преобразованы для использования этой абстракции цикла обучения, а также примеров обучения контролируемой текстовой модели. Однако текущий цикл обучения может не подходить для всех моделей машинного обучения.

Реализация Swift для обобщенного цикла обучения TensorFlow находится под сильным влиянием Learner от fastai . Подробнее об их дизайне можно узнать в «fastai: многоуровневый API для глубокого обучения» и в презентации Сильвена Гуггера «Fast.ai — бесконечно настраиваемый цикл обучения» .

Использование

Пример ResNet-CIFAR10 хорошо демонстрирует, как использовать этот цикл обучения на практике. Сначала импортируйте модуль:

import TrainingLoop

затем выберите серверную часть ускорителя, настроив Device . В данном случае мы выберем бэкэнд на базе X10 XLA и воспользуемся первым доступным ускорителем:

let device = Device.defaultXLA

Следующим шагом является настройка набора данных, модели и оптимизатора для использования в цикле обучения:

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

а затем настройте цикл обучения:

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

Цикл обучения предполагает, что используемый вами набор данных соответствует API Epochs, и позволяет вам указать, какие разбиения в наборе данных использовать для обучения и проверки. Любую функцию потерь можно использовать после помещения в совместимую оболочку, например softmaxCrossEntropy здесь .

Текущие показатели, которые можно фиксировать, включают:

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

Наконец, чтобы выполнить обучение, вы вызываете следующее:

try! trainingLoop.fit(&model, epochs: 10, on: device)

Это обучит модель в течение 10 эпох с использованием указанного нами бэкенда ускорителя. Статистика будет отображаться во время обучения на консоли с помощью анимированной подсказки.

Обратные вызовы

Настройка этого обобщенного цикла обучения происходит за счет использования обратных вызовов. Эти обратные вызовы можно подключить к различным точкам цикла.

Несколько встроенных обратных вызовов предоставляют функциональные возможности, которые можно добавить в любой цикл обучения. К ним относятся:

  • Запись статистики в файлы со значениями, разделенными запятыми (CSV).
  • Настройка скорости обучения по индивидуальному графику
  • Мониторинг и графическое отображение прогресса обучения с помощью TensorBoard

В дополнение к этому вы можете создавать свои собственные обратные вызовы, чтобы добавить ряд дополнительных функций к стандартному циклу обучения.

CSV-ведение журнала

Класс CSVLogger инкапсулирует обратный вызов, который записывает статистику обучения в формате значений, разделенных запятыми, в выбранный вами файл. Этот файл начинается со столбцов, помеченных как epoch , batch и любые метрики, которые вы включили в своем цикле обучения. Затем для каждой партии будет записана одна строка с текущими значениями этих столбцов.

Чтобы добавить ведение журнала CSV в цикл обучения, добавьте что-то вроде следующего в массив обратных вызовов, предоставленный для параметра callbacks: для вашего TrainingLoop :

try! CSVLogger(path: "file.csv").log

Например, образец LeNet-MNIST использует это в цикле обучения.

Графики обучения

При обучении модели часто меняется скорость обучения, предоставляемая оптимизатору в процессе обучения. Это может быть как простое линейное снижение с течением времени, так и сложное, например циклы нагрева и снижения, описываемые сложными функциями.

Обратный вызов learningRateScheduler предоставляет средства описания графиков скорости обучения, состоящих из разных сегментов, каждый из которых имеет свою собственную форму. Это достигается путем определения LearningRateSchedule состоящего из ScheduleSegment , каждый из которых имеет Shape , определенную функцией, начальной скоростью обучения и конечной скоростью обучения.

Например, образец BERT-CoLA использует линейное увеличение скорости обучения во время периода разминки и линейное снижение после этого. Для этого обратный вызов расписания скорости обучения определяется следующим образом:

learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)

Два ScheduleSegment определяют скорость обучения, которая начинается с 0 и линейно увеличивается до peakLearningRate в течение серии из 10 дискретных шагов, затем начинается с конечной скорости обучения предыдущего шага и линейно уменьшается до 0 к концу процесса обучения.

Интеграция с TensorBoard

TensorBoard — это мощный инструмент визуализации для мониторинга обучения модели, анализа завершения обучения или сравнения прогонов обучения. Swift для TensorFlow поддерживает визуализацию TensorBoard за счет использования модуля TensorBoard в репозитории моделей, который предоставляет обратные вызовы, регистрирующие показатели обучения.

Пример GPT2-WikiText2 показывает, как добавить ведение журнала TensorBoard к обучению модели. Сначала импортируйте модуль TensorBoard . Тогда это так же просто, как добавить tensorBoardStatisticsLogger() в массив callbacks: TrainingLoop :.

По умолчанию каждый запуск обучения будет записываться в каталог run/tensorboard/stats . Чтобы просмотреть это в Tensorboard, запустите

tensorboard --logdir ./run/tensorboard/stats

и TensorBoard должен запустить локальный сервер, на котором вы сможете просматривать показатели обучения. Результаты обучения и проверки должны отображаться отдельно, и каждый прогон имеет уникальную временную метку, чтобы можно было легко сравнивать несколько прогонов одной и той же модели.

Дизайн Swift для интеграции TensorFlow с TensorBoard был вдохновлен tensorboardX . Обратные вызовы TensorBoard напрямую создают соответствующие буферы протоколов событий и сводных данных и записывают их в файл журнала во время обучения.

Пользовательские обратные вызовы

Помимо встроенных обратных вызовов, описанных выше, у вас есть возможность настроить функцию обучающих циклов, создав свои собственные обратные вызовы. Эти обратные вызовы представляют собой функции, имеющие сигнатуру, подобную следующей:

func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}

Цикл обучения и связанное с ним состояние передаются в качестве первого параметра. Текущая часть цикла, на которую отвечает обратный вызов, предоставляется через event . Событие цикла обучения имеет одно из следующих состояний, каждое из которых соответствует отдельной точке жизненного цикла цикла:

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

Ваша функция обратного вызова может активировать свою логику в любой комбинации вышеуказанных состояний, что позволяет извлекать данные из цикла обучения или иным образом управлять им разными способами.