Set di dati

In molti modelli di machine learning, in particolare per l’apprendimento supervisionato, i set di dati sono una parte vitale del processo di formazione. Swift per TensorFlow fornisce wrapper per diversi set di dati comuni all'interno del modulo Datasets nel repository dei modelli . Questi wrapper facilitano l'uso di set di dati comuni con modelli basati su Swift e si integrano bene con il ciclo di addestramento generalizzato di Swift per TensorFlow.

Wrapper del set di dati forniti

Questi sono i wrapper del set di dati attualmente forniti all'interno del repository dei modelli:

Per utilizzare uno di questi wrapper di set di dati all'interno di un progetto Swift, aggiungi Datasets come dipendenza al tuo target Swift e importa il modulo:

import Datasets

La maggior parte dei wrapper di set di dati sono progettati per produrre batch di dati etichettati mescolati in modo casuale. Ad esempio, per utilizzare il set di dati CIFAR-10, devi prima inizializzarlo con la dimensione batch desiderata:

let dataset = CIFAR10(batchSize: 100)

Al primo utilizzo, i wrapper del set di dati Swift per TensorFlow scaricheranno automaticamente il set di dati originale, estrarranno e analizzeranno tutti gli archivi rilevanti, quindi memorizzeranno il set di dati elaborato in una directory della cache locale dell'utente. Gli usi successivi dello stesso set di dati verranno caricati direttamente dalla cache locale.

Per impostare un ciclo di addestramento manuale che coinvolga questo set di dati, utilizzeresti qualcosa di simile al seguente:

for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
  Context.local.learningPhase = .training
  ...
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    ...
  }
}

Quanto sopra imposta un iteratore attraverso 100 epoche ( .prefix(100) ) e restituisce l'indice numerico dell'epoca corrente e una sequenza mappata pigramente sui batch mescolati che compongono quell'epoca. All'interno di ciascuna epoca di addestramento, i batch vengono ripetuti ed estratti per l'elaborazione. Nel caso del wrapper del set di dati CIFAR10 , ogni batch è un LabeledImage , che fornisce un Tensor<Float> contenente tutte le immagini di quel batch e un Tensor<Int32> con le etichette corrispondenti.

Nel caso di CIFAR-10, l'intero set di dati è piccolo e può essere caricato in memoria in una sola volta, ma per altri set di dati più grandi i batch vengono caricati pigramente dal disco ed elaborati nel punto in cui viene ottenuto ciascun batch. Ciò impedisce l'esaurimento della memoria con set di dati più grandi.

L'API di Epochs

La maggior parte di questi wrapper di set di dati sono costruiti su un'infrastruttura condivisa che abbiamo chiamato Epochs API . Epochs fornisce componenti flessibili destinati a supportare un'ampia varietà di tipi di set di dati, dal testo alle immagini e altro ancora.

Se desideri creare il tuo wrapper del set di dati Swift, molto probabilmente vorrai utilizzare l'API Epochs per farlo. Tuttavia, per i casi comuni, come i set di dati di classificazione delle immagini, consigliamo vivamente di iniziare da un modello basato su uno dei wrapper del set di dati esistenti e di modificarlo per soddisfare le proprie esigenze specifiche.

Ad esempio, esaminiamo il wrapper del set di dati CIFAR-10 e come funziona. Il nucleo del set di dati di addestramento è definito qui:

let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
  .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
    return batches.lazy.map{
      makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
  }
}

Il risultato della funzione loadCIFARTrainingFiles() è una matrice di tuple (data: [UInt8], label: Int32) per ogni immagine nel set di dati di training. Questo viene quindi fornito a TrainingEpochs(samples:batchSize:entropy:) per creare una sequenza infinita di epoche con batch di batchSize . Puoi fornire il tuo generatore di numeri casuali nei casi in cui desideri un comportamento di batch deterministico, ma per impostazione predefinita viene utilizzato SystemRandomNumberGenerator .

Da lì, le mappe pigre sui batch culminano nella funzione makeBatch(samples:mean:standardDeviation:device:) . Questa è una funzione personalizzata in cui si trova la pipeline di elaborazione delle immagini effettiva per il set di dati CIFAR-10, quindi diamo un'occhiata a questo:

fileprivate func makeBatch<BatchSamples: Collection>(
  samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)

  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
  imageTensor /= 255.0
  if let mean = mean, let standardDeviation = standardDeviation {
    imageTensor = (imageTensor - mean) / standardDeviation
  }

  let labels = Tensor<Int32>(samples.map(\.label), on: device)
  return LabeledImage(data: imageTensor, label: labels)
}

Le due righe di questa funzione concatenano tutti i byte data dai BatchSamples in entrata in un Tensor<UInt8> che corrisponde al layout dei byte delle immagini all'interno del set di dati CIFAR-10 non elaborato. Successivamente, i canali dell'immagine vengono riordinati in modo che corrispondano a quelli previsti nei nostri modelli di classificazione delle immagini standard e i dati dell'immagine vengono riformulati in un Tensor<Float> per l'utilizzo del modello.

È possibile fornire parametri di normalizzazione opzionali per regolare ulteriormente i valori del canale dell'immagine, un processo comune durante l'addestramento di molti modelli di classificazione delle immagini. Il parametro di normalizzazione Tensor s viene creato una volta all'inizializzazione del set di dati e quindi passato a makeBatch() come ottimizzazione per impedire la creazione ripetuta di piccoli tensori temporanei con gli stessi valori.

Infine, le etichette intere vengono inserite in un Tensor<Int32> e la coppia tensore immagine/etichetta restituita in un LabeledImage . Una LabeledImage è un caso specifico di LabeledData , una struttura con dati ed etichette conformi al protocollo Collatable dell'API Eppch.

Per ulteriori esempi dell'API Epochs in diversi tipi di set di dati, puoi esaminare gli altri wrapper di set di dati all'interno del repository dei modelli.