Ciclo di formazione

Quando si addestra un modello di machine learning, è comune avere un ciclo in cui i dati di addestramento vengono acquisiti (o generati), i batch vengono eseguiti attraverso un modello, i gradienti ottenuti e il modello aggiornato tramite un ottimizzatore. Sebbene sia possibile scrivere un ciclo di addestramento personalizzato per ogni applicazione di addestramento, Swift per TensorFlow fornisce un'astrazione sperimentale del ciclo di addestramento che può semplificare questo processo.

Il modulo TrainingLoop all'interno del repository dei modelli contiene la versione corrente di questo ciclo di addestramento generalizzato sperimentale. È strutturato in modo tale da integrarsi con wrapper di set di dati conformi all'API Epochs per una facile acquisizione dei dati e per automatizzare l'interazione di modelli, set di dati e ottimizzatori con backend di accelerazione per ottenere prestazioni ottimali. È possibile ottenere una forte personalizzazione del processo di formazione tramite l'uso di callback.

La maggior parte degli esempi basati su immagini nel repository dei modelli sono stati convertiti per utilizzare questa astrazione del ciclo di addestramento, così come gli esempi di addestramento del modello di testo supervisionato. Tuttavia, il ciclo di addestramento potrebbe non essere appropriato nella sua progettazione attuale per tutti i modelli di machine learning.

L'implementazione del ciclo di formazione generalizzato di Swift per TensorFlow è fortemente influenzata da Learner di fastai . Per ulteriori informazioni sulla loro progettazione, fare riferimento a "fastai: A Layered API for Deep Learning" e alla presentazione di Sylvain Gugger "Fast.ai - An infinitely personalizzabile training loop" .

Utilizzo

L'esempio ResNet-CIFAR10 fornisce una buona dimostrazione di come utilizzare nella pratica questo ciclo di addestramento. Innanzitutto, importa il modulo:

import TrainingLoop

quindi scegli un backend acceleratore configurando un Device . In questo caso, selezioneremo il backend basato su X10 XLA e utilizzeremo il primo acceleratore disponibile:

let device = Device.defaultXLA

Il passaggio successivo consiste nel configurare il set di dati, il modello e l'ottimizzatore da utilizzare con il ciclo di addestramento:

let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

e quindi impostare il ciclo di formazione:

var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.accuracy])

Il ciclo di training presuppone che il set di dati che stai utilizzando sia conforme all'API Epochs e ti consente di specificare quali suddivisioni all'interno del set di dati utilizzare per il training e la convalida. Qualsiasi funzione di perdita può essere utilizzata una volta inserita in un wrapper compatibile, come softmaxCrossEntropy è qui .

Le metriche attuali che possono essere acquisite includono:

  • loss
  • accuracy
  • top5Accuracy
  • matthewsCorrelationCoefficient
  • perplexity

Infine, per eseguire la formazione, chiamare quanto segue:

try! trainingLoop.fit(&model, epochs: 10, on: device)

Ciò addestrerà il modello per 10 epoche utilizzando il backend dell'acceleratore che abbiamo specificato. Le statistiche verranno visualizzate durante l'allenamento sulla console utilizzando un messaggio animato.

Richiamate

La personalizzazione di questo ciclo di addestramento generalizzato avviene tramite l'uso di callback. Questi callback possono essere agganciati a vari punti all'interno del ciclo.

Numerosi callback integrati forniscono funzionalità che possono essere aggiunte a qualsiasi ciclo di training. Questi includono:

  • Registrazione delle statistiche in file con valori separati da virgole (CSV).
  • Regolazione del tasso di apprendimento in base a un programma personalizzato
  • Monitoraggio e rappresentazione grafica dei progressi della formazione tramite TensorBoard

Oltre a questi, puoi creare i tuoi callback personalizzati per aggiungere una gamma di funzionalità aggiuntive a un ciclo di formazione standard.

Registrazione CSV

La classe CSVLogger incapsula un callback che scriverà le statistiche di addestramento in un formato con valori separati da virgole in un file di tua scelta. Questo file inizierà con le colonne etichettate epoch , batch e qualsiasi metrica che hai abilitato nel ciclo di allenamento. Verrà quindi scritta una riga per ciascun batch, con i valori correnti di tali colonne.

Per aggiungere la registrazione CSV al ciclo di training, aggiungi qualcosa di simile a quanto segue a una serie di callback forniti ai callbacks: parametro per TrainingLoop :

try! CSVLogger(path: "file.csv").log

Ad esempio, l' esempio LeNet-MNIST lo utilizza all'interno del proprio ciclo di training.

Orari delle tariffe di apprendimento

È comune quando si addestra un modello modificare la velocità di apprendimento fornita a un ottimizzatore durante il processo di addestramento. Questo può essere semplice come una diminuzione lineare nel tempo, o complesso come cicli di riscaldamento e declino descritti da funzioni complicate.

Il callback learningRateScheduler fornisce i mezzi per descrivere le pianificazioni della frequenza di apprendimento composte da diversi segmenti, ciascuno con la propria forma distinta. Ciò si ottiene definendo un LearningRateSchedule composto da ScheduleSegment , ciascuno dei quali ha una Shape definita da una funzione, una velocità di apprendimento iniziale e una velocità di apprendimento finale.

Ad esempio, il campione BERT-CoLA utilizza un aumento lineare del tasso di apprendimento durante un periodo di riscaldamento e successivamente una diminuzione lineare. A tale scopo, la richiamata della pianificazione del tasso di apprendimento è definita come segue:

learningRateScheduler(
  schedule: makeSchedule(
    [
      ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
      ScheduleSegment(shape: linear, endRate: 0)
    ]
  )
)

I due ScheduleSegment definiscono una velocità di apprendimento che inizia da 0 e aumenta linearmente fino al peakLearningRate su una serie di 10 passaggi discreti, quindi inizia alla velocità di apprendimento finale del passaggio precedente e diminuisce linearmente fino a 0 entro la fine del processo di training.

Integrazione TensorBoard

TensorBoard è un potente strumento di visualizzazione per monitorare l'addestramento del modello, analizzare l'addestramento una volta completato o confrontare le esecuzioni dell'addestramento. Swift per TensorFlow supporta la visualizzazione TensorBoard tramite l'uso del modulo TensorBoard nel repository dei modelli, che fornisce callback che registrano i parametri di addestramento.

L'esempio GPT2-WikiText2 illustra come aggiungere la registrazione TensorBoard all'addestramento del modello. Innanzitutto, importa il modulo TensorBoard . Quindi è semplice come aggiungere tensorBoardStatisticsLogger() ai callbacks: di TrainingLoop : array.

Per impostazione predefinita, ciò registrerà ogni esecuzione della formazione all'interno di una directory run/tensorboard/stats . Per visualizzarlo all'interno di Tensorboard, esegui

tensorboard --logdir ./run/tensorboard/stats

e TensorBoard dovrebbe avviare un server locale in cui è possibile visualizzare le metriche di allenamento. I risultati dell'addestramento e della convalida devono essere visualizzati separatamente e ogni esecuzione ha un timestamp univoco per consentire un facile confronto tra più esecuzioni dello stesso modello.

Il design dell'integrazione di Swift per TensorFlow TensorBoard è stato ispirato da tensorboardX . I callback TensorBoard creano direttamente gli eventi appropriati e i buffer del protocollo di riepilogo e li scrivono in un file di registro durante l'addestramento.

Richiamate personalizzate

Oltre alle richiamate integrate descritte sopra, hai la possibilità di personalizzare la funzione dei cicli di addestramento creando le tue richiamate. Questi callback sono funzioni che hanno una firma simile alla seguente:

func customCallback<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    ...
  }
}

Il ciclo di training e lo stato associato vengono passati come primo parametro. La parte corrente del ciclo a cui sta rispondendo il callback viene fornita tramite event . L'evento del ciclo di addestramento ha uno dei seguenti stati, ciascuno corrispondente a un punto diverso nel ciclo di vita del ciclo:

  • fitStart
  • fitEnd
  • epochStart
  • epochEnd
  • trainingStart
  • trainingEnd
  • validationStart
  • validationEnd
  • batchStart
  • batchEnd
  • updateStart
  • inferencePredictionEnd

La funzione di callback può scegliere di attivare la propria logica su qualsiasi combinazione degli stati sopra indicati, il che consente di estrarre dati o controllare in altro modo il ciclo di training in molti modi.