Zbiory danych

W wielu modelach uczenia maszynowego, zwłaszcza w przypadku uczenia się nadzorowanego, zbiory danych stanowią istotną część procesu szkoleniowego. Swift dla TensorFlow zapewnia opakowania dla kilku popularnych zestawów danych w module Datasets w repozytorium modeli . Te opakowania ułatwiają korzystanie z typowych zestawów danych w modelach opartych na języku Swift i dobrze integrują się z uogólnioną pętlą szkoleniową Swift for TensorFlow.

Dostarczone opakowania zestawu danych

Oto aktualnie dostępne opakowania zbioru danych w repozytorium modeli:

Aby użyć jednego z tych opakowań zbioru danych w projekcie Swift, dodaj Datasets jako zależność do celu Swift i zaimportuj moduł:

import Datasets

Większość opakowań zbiorów danych jest zaprojektowana tak, aby generować losowo przetasowane partie oznaczonych etykietami danych. Na przykład, aby użyć zbioru danych CIFAR-10, należy najpierw zainicjować go żądaną wielkością partii:

let dataset = CIFAR10(batchSize: 100)

Przy pierwszym użyciu opakowania zestawu danych Swift for TensorFlow automatycznie pobiorą oryginalny zestaw danych, wyodrębnią i przeanalizują wszystkie odpowiednie archiwa, a następnie zapiszą przetworzony zestaw danych w lokalnym katalogu pamięci podręcznej użytkownika. Kolejne użycia tego samego zestawu danych będą ładowane bezpośrednio z lokalnej pamięci podręcznej.

Aby skonfigurować ręczną pętlę szkoleniową obejmującą ten zbiór danych, możesz użyć czegoś takiego:

for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
  Context.local.learningPhase = .training
  ...
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    ...
  }
}

Powyższe konfiguruje iterator przez 100 epok ( .prefix(100) ) i zwraca indeks liczbowy bieżącej epoki oraz leniwie odwzorowaną sekwencję na przetasowanych partiach tworzących tę epokę. W każdej epoce szkoleniowej partie są poddawane iteracji i wyodrębniane do przetworzenia. W przypadku opakowania zestawu danych CIFAR10 każda partia to LabeledImage , który udostępnia Tensor<Float> zawierający wszystkie obrazy z tej partii oraz Tensor<Int32> z pasującymi etykietami.

W przypadku CIFAR-10 cały zbiór danych jest mały i można go załadować do pamięci na raz, natomiast w przypadku innych większych zbiorów danych partie są ładowane leniwie z dysku i przetwarzane w momencie uzyskania każdej partii. Zapobiega to wyczerpaniu pamięci w przypadku większych zestawów danych.

Interfejs API epok

Większość opakowań zbioru danych opiera się na współdzielonej infrastrukturze, którą nazwaliśmy interfejsem API Epochs . Epochs zapewnia elastyczne komponenty przeznaczone do obsługi szerokiej gamy typów zbiorów danych, od tekstu po obrazy i nie tylko.

Jeśli chcesz utworzyć własne opakowanie zbioru danych Swift, najprawdopodobniej będziesz chciał użyć do tego interfejsu API Epochs. Jednak w typowych przypadkach, takich jak zbiory danych klasyfikacji obrazów, zdecydowanie zalecamy rozpoczęcie od szablonu opartego na jednym z istniejących opakowań zbioru danych i zmodyfikowanie go w celu spełnienia konkretnych potrzeb.

Jako przykład przyjrzyjmy się opakowaniu zbioru danych CIFAR-10 i jego działaniu. Rdzeń zbioru danych szkoleniowych jest zdefiniowany tutaj:

let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
  .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
    return batches.lazy.map{
      makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
  }
}

Wynikiem funkcji loadCIFARTrainingFiles() jest tablica (data: [UInt8], label: Int32) krotek dla każdego obrazu w zbiorze danych szkoleniowych. Wartość ta jest następnie przekazywana do TrainingEpochs(samples:batchSize:entropy:) w celu utworzenia nieskończonej sekwencji epok z partiami batchSize . Możesz udostępnić własny generator liczb losowych w przypadkach, gdy chcesz zachować deterministyczne zachowanie wsadowe, ale domyślnie używany jest SystemRandomNumberGenerator .

Stamtąd leniwe mapy w partiach kończą się funkcją makeBatch(samples:mean:standardDeviation:device:) . Jest to funkcja niestandardowa, w której zlokalizowany jest rzeczywisty potok przetwarzania obrazu dla zbioru danych CIFAR-10, więc przyjrzyjmy się temu:

fileprivate func makeBatch<BatchSamples: Collection>(
  samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)

  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
  imageTensor /= 255.0
  if let mean = mean, let standardDeviation = standardDeviation {
    imageTensor = (imageTensor - mean) / standardDeviation
  }

  let labels = Tensor<Int32>(samples.map(\.label), on: device)
  return LabeledImage(data: imageTensor, label: labels)
}

Dwie linie tej funkcji łączą wszystkie bajty data z przychodzących próbek BatchSamples w Tensor<UInt8> , który pasuje do układu bajtów obrazów w surowym zestawie danych CIFAR-10. Następnie kolejność kanałów obrazu jest zmieniana tak, aby odpowiadała oczekiwaniom w naszych standardowych modelach klasyfikacji obrazów, a dane obrazu są ponownie rzutowane do Tensor<Float> w celu wykorzystania modelu.

Można udostępnić opcjonalne parametry normalizacji w celu dalszego dostosowania wartości kanałów obrazu, co jest procesem powszechnym podczas uczenia wielu modeli klasyfikacji obrazów. Parametr normalizacyjny Tensor s jest tworzony raz podczas inicjalizacji zestawu danych, a następnie przekazywany do makeBatch() w ramach optymalizacji, aby zapobiec wielokrotnemu tworzeniu małych tymczasowych tensorów o tych samych wartościach.

Na koniec etykiety całkowite są umieszczane w Tensor<Int32> , a para tensorów obraz/etykieta zwracana w LabeledImage . LabeledImage to specyficzny przypadek LabeledData , struktury zawierającej dane i etykiety zgodne z protokołem Collatable interfejsu API Eppch.

Aby uzyskać więcej przykładów interfejsu API Epochs w różnych typach zestawów danych, możesz sprawdzić inne opakowania zestawu danych w repozytorium modeli.