Dokumen ini memberikan tips performa khusus TensorFlow Datasets (TFDS). Perhatikan bahwa TFDS menyediakan kumpulan data sebagai objek tf.data.Dataset
, sehingga saran dari panduan tf.data
tetap berlaku.
Kumpulan data tolok ukur
Gunakan tfds.benchmark(ds)
untuk melakukan benchmark pada objek tf.data.Dataset
.
Pastikan untuk menunjukkan batch_size=
untuk menormalkan hasil (misalnya 100 iter/detik -> 3200 ex/detik). Ini berfungsi dengan semua iterable (misalnya tfds.benchmark(tfds.as_numpy(ds))
).
ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)
Kumpulan data kecil (kurang dari 1 GB)
Semua kumpulan data TFDS menyimpan data pada disk dalam format TFRecord
. Untuk kumpulan data kecil (misalnya MNIST, CIFAR-10/-100), membaca dari .tfrecord
dapat menambah overhead yang signifikan.
Saat kumpulan data tersebut sesuai dengan memori, kinerja dapat ditingkatkan secara signifikan dengan melakukan cache atau memuat kumpulan data terlebih dahulu. Perhatikan bahwa TFDS secara otomatis menyimpan kumpulan data kecil dalam cache (bagian berikut berisi detailnya).
Menyimpan kumpulan data dalam cache
Berikut adalah contoh pipeline data yang secara eksplisit menyimpan dataset dalam cache setelah normalisasi gambar.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds, ds_info = tfds.load(
'mnist',
split='train',
as_supervised=True, # returns `(img, label)` instead of dict(image=, ...)
with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
Saat melakukan iterasi pada kumpulan data ini, iterasi kedua akan jauh lebih cepat daripada iterasi pertama berkat caching.
Caching otomatis
Secara default, TFDS melakukan cache otomatis (dengan ds.cache()
) kumpulan data yang memenuhi batasan berikut:
- Ukuran total kumpulan data (semua pemisahan) ditentukan dan <250 MiB
-
shuffle_files
dinonaktifkan, atau hanya satu pecahan yang dibaca
Anda dapat memilih keluar dari cache otomatis dengan meneruskan try_autocaching=False
ke tfds.ReadConfig
di tfds.load
. Lihat dokumentasi katalog kumpulan data untuk melihat apakah kumpulan data tertentu akan menggunakan cache otomatis.
Memuat data lengkap sebagai Tensor tunggal
Jika kumpulan data Anda sesuai dengan memori, Anda juga dapat memuat kumpulan data lengkap sebagai satu larik Tensor atau NumPy. Hal ini dapat dilakukan dengan menyetel batch_size=-1
ke kumpulan semua contoh dalam satu tf.Tensor
. Kemudian gunakan tfds.as_numpy
untuk konversi dari tf.Tensor
ke np.array
.
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
'mnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
Kumpulan data besar
Kumpulan data berukuran besar akan dipecah (dipecah menjadi beberapa file) dan biasanya tidak muat di memori, sehingga tidak boleh di-cache.
Acak dan latih
Selama pelatihan, penting untuk mengacak data dengan baik - data yang diacak dengan buruk dapat mengakibatkan akurasi pelatihan yang lebih rendah.
Selain menggunakan ds.shuffle
untuk mengacak catatan, Anda juga harus mengatur shuffle_files=True
untuk mendapatkan perilaku pengacakan yang baik untuk kumpulan data yang lebih besar yang dibagi menjadi beberapa file. Jika tidak, zaman akan membaca pecahan dalam urutan yang sama, sehingga data tidak akan benar-benar diacak.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
Selain itu, ketika shuffle_files=True
, TFDS menonaktifkan options.deterministic
, yang mungkin memberikan sedikit peningkatan kinerja. Untuk mendapatkan pengacakan deterministik, Anda dapat memilih untuk tidak ikut serta dalam fitur ini dengan tfds.ReadConfig
: baik dengan menyetel read_config.shuffle_seed
atau menimpa read_config.options.deterministic
.
Memecahkan data Anda secara otomatis ke seluruh pekerja (TF)
Saat melatih beberapa pekerja, Anda dapat menggunakan argumen input_context
dari tfds.ReadConfig
, sehingga setiap pekerja akan membaca subset data.
input_context = tf.distribute.InputContext(
input_pipeline_id=1, # Worker id
num_input_pipelines=4, # Total number of workers
)
read_config = tfds.ReadConfig(
input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)
Ini melengkapi API subsplit. Pertama, API subplit diterapkan: train[:50%]
diubah menjadi daftar file untuk dibaca. Kemudian, operasi ds.shard()
diterapkan pada file tersebut. Misalnya, saat menggunakan train[:50%]
dengan num_input_pipelines=2
, masing-masing dari 2 pekerja akan membaca 1/4 data.
Ketika shuffle_files=True
, file diacak dalam satu pekerja, namun tidak antar pekerja. Setiap pekerja akan membaca subset file yang sama antar zaman.
Memecahkan data Anda secara otomatis ke seluruh pekerja (Jax)
Dengan Jax, Anda dapat menggunakan API tfds.split_for_jax_process
atau tfds.even_splits
untuk mendistribusikan data Anda ke seluruh pekerja. Lihat panduan API terpisah .
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process
adalah alias sederhana untuk:
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
Penguraian kode gambar lebih cepat
Secara default, TFDS secara otomatis menerjemahkan gambar. Namun, ada beberapa kasus di mana akan lebih baik jika melewatkan decoding gambar dengan tfds.decode.SkipDecoding
dan secara manual menerapkan operasi tf.io.decode_image
:
- Saat memfilter contoh (dengan
tf.data.Dataset.filter
), untuk mendekode gambar setelah contoh difilter. - Saat memotong gambar, untuk menggunakan operasi
tf.image.decode_and_crop_jpeg
yang menyatu.
Kode untuk kedua contoh tersedia di panduan dekode .
Lewati fitur yang tidak digunakan
Jika Anda hanya menggunakan sebagian fitur, beberapa fitur dapat dilewati sepenuhnya. Jika kumpulan data Anda memiliki banyak fitur yang tidak digunakan, tidak mendekode fitur tersebut dapat meningkatkan performa secara signifikan. Lihat https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features
tf.data menggunakan semua RAM saya!
Jika RAM Anda terbatas, atau jika Anda memuat banyak kumpulan data secara paralel saat menggunakan tf.data
, berikut beberapa opsi yang dapat membantu:
Ganti ukuran buffer
builder.as_dataset(
read_config=tfds.ReadConfig(
...
override_buffer_size=1024, # Save quite a bit of RAM.
),
...
)
Ini mengesampingkan buffer_size
yang diteruskan ke TFRecordDataset
(atau yang setara): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args
Gunakan tf.data.Dataset.with_options untuk menghentikan perilaku ajaib
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options
options = tf.data.Options()
# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False
data = data.with_options(options)