Tips performa

Dokumen ini memberikan tips performa khusus TensorFlow Datasets (TFDS). Perhatikan bahwa TFDS menyediakan kumpulan data sebagai objek tf.data.Dataset , sehingga saran dari panduan tf.data tetap berlaku.

Kumpulan data tolok ukur

Gunakan tfds.benchmark(ds) untuk melakukan benchmark pada objek tf.data.Dataset .

Pastikan untuk menunjukkan batch_size= untuk menormalkan hasil (misalnya 100 iter/detik -> 3200 ex/detik). Ini berfungsi dengan semua iterable (misalnya tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Kumpulan data kecil (kurang dari 1 GB)

Semua kumpulan data TFDS menyimpan data pada disk dalam format TFRecord . Untuk kumpulan data kecil (misalnya MNIST, CIFAR-10/-100), membaca dari .tfrecord dapat menambah overhead yang signifikan.

Saat kumpulan data tersebut sesuai dengan memori, kinerja dapat ditingkatkan secara signifikan dengan melakukan cache atau memuat kumpulan data terlebih dahulu. Perhatikan bahwa TFDS secara otomatis menyimpan kumpulan data kecil dalam cache (bagian berikut berisi detailnya).

Menyimpan kumpulan data dalam cache

Berikut adalah contoh pipeline data yang secara eksplisit menyimpan dataset dalam cache setelah normalisasi gambar.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Saat melakukan iterasi pada kumpulan data ini, iterasi kedua akan jauh lebih cepat daripada iterasi pertama berkat caching.

Caching otomatis

Secara default, TFDS melakukan cache otomatis (dengan ds.cache() ) kumpulan data yang memenuhi batasan berikut:

  • Ukuran total kumpulan data (semua pemisahan) ditentukan dan <250 MiB
  • shuffle_files dinonaktifkan, atau hanya satu pecahan yang dibaca

Anda dapat memilih keluar dari cache otomatis dengan meneruskan try_autocaching=False ke tfds.ReadConfig di tfds.load . Lihat dokumentasi katalog kumpulan data untuk melihat apakah kumpulan data tertentu akan menggunakan cache otomatis.

Memuat data lengkap sebagai Tensor tunggal

Jika kumpulan data Anda sesuai dengan memori, Anda juga dapat memuat kumpulan data lengkap sebagai satu larik Tensor atau NumPy. Hal ini dapat dilakukan dengan menyetel batch_size=-1 ke kumpulan semua contoh dalam satu tf.Tensor . Kemudian gunakan tfds.as_numpy untuk konversi dari tf.Tensor ke np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Kumpulan data besar

Kumpulan data berukuran besar akan dipecah (dipecah menjadi beberapa file) dan biasanya tidak muat di memori, sehingga tidak boleh di-cache.

Acak dan latih

Selama pelatihan, penting untuk mengacak data dengan baik - data yang diacak dengan buruk dapat mengakibatkan akurasi pelatihan yang lebih rendah.

Selain menggunakan ds.shuffle untuk mengacak catatan, Anda juga harus mengatur shuffle_files=True untuk mendapatkan perilaku pengacakan yang baik untuk kumpulan data yang lebih besar yang dibagi menjadi beberapa file. Jika tidak, zaman akan membaca pecahan dalam urutan yang sama, sehingga data tidak akan benar-benar diacak.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Selain itu, ketika shuffle_files=True , TFDS menonaktifkan options.deterministic , yang mungkin memberikan sedikit peningkatan kinerja. Untuk mendapatkan pengacakan deterministik, Anda dapat memilih untuk tidak ikut serta dalam fitur ini dengan tfds.ReadConfig : baik dengan menyetel read_config.shuffle_seed atau menimpa read_config.options.deterministic .

Memecahkan data Anda secara otomatis ke seluruh pekerja (TF)

Saat melatih beberapa pekerja, Anda dapat menggunakan argumen input_context dari tfds.ReadConfig , sehingga setiap pekerja akan membaca subset data.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Ini melengkapi API subsplit. Pertama, API subplit diterapkan: train[:50%] diubah menjadi daftar file untuk dibaca. Kemudian, operasi ds.shard() diterapkan pada file tersebut. Misalnya, saat menggunakan train[:50%] dengan num_input_pipelines=2 , masing-masing dari 2 pekerja akan membaca 1/4 data.

Ketika shuffle_files=True , file diacak dalam satu pekerja, namun tidak antar pekerja. Setiap pekerja akan membaca subset file yang sama antar zaman.

Memecahkan data Anda secara otomatis ke seluruh pekerja (Jax)

Dengan Jax, Anda dapat menggunakan API tfds.split_for_jax_process atau tfds.even_splits untuk mendistribusikan data Anda ke seluruh pekerja. Lihat panduan API terpisah .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process adalah alias sederhana untuk:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Penguraian kode gambar lebih cepat

Secara default, TFDS secara otomatis menerjemahkan gambar. Namun, ada beberapa kasus di mana akan lebih baik jika melewatkan decoding gambar dengan tfds.decode.SkipDecoding dan secara manual menerapkan operasi tf.io.decode_image :

Kode untuk kedua contoh tersedia di panduan dekode .

Lewati fitur yang tidak digunakan

Jika Anda hanya menggunakan sebagian fitur, beberapa fitur dapat dilewati sepenuhnya. Jika kumpulan data Anda memiliki banyak fitur yang tidak digunakan, tidak mendekode fitur tersebut dapat meningkatkan performa secara signifikan. Lihat https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features

tf.data menggunakan semua RAM saya!

Jika RAM Anda terbatas, atau jika Anda memuat banyak kumpulan data secara paralel saat menggunakan tf.data , berikut beberapa opsi yang dapat membantu:

Ganti ukuran buffer

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

Ini mengesampingkan buffer_size yang diteruskan ke TFRecordDataset (atau yang setara): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args

Gunakan tf.data.Dataset.with_options untuk menghentikan perilaku ajaib

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)