Set data

Dalam banyak model pembelajaran mesin, terutama untuk pembelajaran terawasi, kumpulan data merupakan bagian penting dari proses pelatihan. Swift untuk TensorFlow menyediakan pembungkus untuk beberapa kumpulan data umum dalam modul Kumpulan Data di repositori model . Wrapper ini memudahkan penggunaan kumpulan data umum dengan model berbasis Swift dan terintegrasi dengan baik dengan loop pelatihan umum Swift untuk TensorFlow.

Pembungkus kumpulan data disediakan

Berikut ini adalah wrapper dataset yang disediakan saat ini dalam repositori model:

Untuk menggunakan salah satu pembungkus kumpulan data ini dalam proyek Swift, tambahkan Datasets sebagai dependensi ke target Swift Anda dan impor modul:

import Datasets

Sebagian besar pembungkus kumpulan data dirancang untuk menghasilkan kumpulan data berlabel yang diacak secara acak. Misalnya, untuk menggunakan dataset CIFAR-10, Anda terlebih dahulu menginisialisasinya dengan ukuran batch yang diinginkan:

let dataset = CIFAR10(batchSize: 100)

Saat pertama kali digunakan, wrapper kumpulan data Swift untuk TensorFlow akan secara otomatis mengunduh kumpulan data asli untuk Anda, mengekstrak dan mengurai semua arsip yang relevan, lalu menyimpan kumpulan data yang diproses dalam direktori cache lokal pengguna. Penggunaan kumpulan data yang sama selanjutnya akan dimuat langsung dari cache lokal.

Untuk menyiapkan loop pelatihan manual yang melibatkan kumpulan data ini, Anda akan menggunakan sesuatu seperti berikut:

for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() {
  Context.local.learningPhase = .training
  ...
  for batch in epochBatches {
    let (images, labels) = (batch.data, batch.label)
    ...
  }
}

Perintah di atas menyiapkan iterator melalui 100 epoch ( .prefix(100) ), dan mengembalikan indeks numerik epoch saat ini dan urutan yang dipetakan secara malas pada kumpulan acak yang membentuk epoch tersebut. Dalam setiap periode pelatihan, batch diiterasi dan diekstraksi untuk diproses. Dalam kasus pembungkus kumpulan data CIFAR10 , setiap batch adalah LabeledImage , yang menyediakan Tensor<Float> yang berisi semua gambar dari batch tersebut dan Tensor<Int32> dengan label yang cocok.

Dalam kasus CIFAR-10, seluruh kumpulan data berukuran kecil dan dapat dimuat ke dalam memori sekaligus, namun untuk kumpulan data lain yang lebih besar, kumpulan data dimuat dengan lambat dari disk dan diproses pada titik di mana setiap kumpulan diperoleh. Hal ini mencegah kehabisan memori dengan kumpulan data yang lebih besar.

API Epoch

Sebagian besar wrapper kumpulan data ini dibuat pada infrastruktur bersama yang kami sebut Epochs API . Epochs menyediakan komponen fleksibel yang dimaksudkan untuk mendukung berbagai jenis kumpulan data, mulai dari teks hingga gambar dan banyak lagi.

Jika Anda ingin membuat wrapper kumpulan data Swift Anda sendiri, kemungkinan besar Anda ingin menggunakan Epochs API untuk melakukannya. Namun, untuk kasus umum, seperti kumpulan data klasifikasi gambar, kami sangat menyarankan untuk memulai dari templat berdasarkan salah satu pembungkus kumpulan data yang ada dan memodifikasinya untuk memenuhi kebutuhan spesifik Anda.

Sebagai contoh, mari kita periksa pembungkus kumpulan data CIFAR-10 dan cara kerjanya. Inti dari kumpulan data pelatihan didefinisikan di sini:

let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
  .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
    return batches.lazy.map{
      makeBatch(samples: $0, mean: mean, standardDeviation: standardDeviation, device: device)
  }
}

Hasil dari fungsi loadCIFARTrainingFiles() adalah array tupel (data: [UInt8], label: Int32) untuk setiap gambar dalam dataset pelatihan. Ini kemudian diberikan kepada TrainingEpochs(samples:batchSize:entropy:) untuk membuat urutan zaman yang tak terbatas dengan kumpulan batchSize . Anda dapat menyediakan generator nomor acak Anda sendiri jika Anda menginginkan perilaku pengelompokan deterministik, tetapi secara default SystemRandomNumberGenerator digunakan.

Dari sana, peta malas pada batch berujung pada fungsi makeBatch(samples:mean:standardDeviation:device:) . Ini adalah fungsi khusus tempat alur pemrosesan gambar sebenarnya untuk kumpulan data CIFAR-10 berada, jadi mari kita lihat:

fileprivate func makeBatch<BatchSamples: Collection>(
  samples: BatchSamples, mean: Tensor<Float>?, standardDeviation: Tensor<Float>?, device: Device
) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)

  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
  imageTensor /= 255.0
  if let mean = mean, let standardDeviation = standardDeviation {
    imageTensor = (imageTensor - mean) / standardDeviation
  }

  let labels = Tensor<Int32>(samples.map(\.label), on: device)
  return LabeledImage(data: imageTensor, label: labels)
}

Dua baris fungsi ini menggabungkan semua byte data dari BatchSamples yang masuk menjadi Tensor<UInt8> yang cocok dengan tata letak byte gambar dalam kumpulan data CIFAR-10 mentah. Selanjutnya, saluran gambar disusun ulang agar sesuai dengan yang diharapkan dalam model klasifikasi gambar standar kami dan data gambar diubah menjadi Tensor<Float> untuk konsumsi model.

Parameter normalisasi opsional dapat diberikan untuk menyesuaikan lebih lanjut nilai saluran gambar, sebuah proses yang umum terjadi saat melatih banyak model klasifikasi gambar. Parameter normalisasi Tensor s dibuat satu kali pada inisialisasi kumpulan data dan kemudian diteruskan ke makeBatch() sebagai pengoptimalan untuk mencegah pembuatan berulang tensor sementara kecil dengan nilai yang sama.

Terakhir, label bilangan bulat ditempatkan di Tensor<Int32> dan pasangan tensor gambar/label dikembalikan dalam LabeledImage . LabeledImage adalah kasus khusus LabeledData , sebuah struct dengan data dan label yang sesuai dengan protokol Collatable API Eppch.

Untuk contoh lebih lanjut tentang Epochs API dalam tipe dataset yang berbeda, Anda dapat memeriksa wrapper dataset lainnya dalam repositori model.