Suggerimenti per il rendimento

Questo documento fornisce suggerimenti sulle prestazioni specifici di TensorFlow Datasets (TFDS). Tieni presente che TFDS fornisce set di dati come oggetti tf.data.Dataset , quindi i consigli della guida tf.data sono ancora validi.

Set di dati di riferimento

Utilizza tfds.benchmark(ds) per confrontare qualsiasi oggetto tf.data.Dataset .

Assicurati di indicare batch_size= per normalizzare i risultati (ad esempio 100 iter/sec -> 3200 ex/sec). Funziona con qualsiasi iterabile (ad esempio tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Piccoli set di dati (meno di 1 GB)

Tutti i set di dati TFDS memorizzano i dati su disco nel formato TFRecord . Per set di dati di piccole dimensioni (ad esempio MNIST, CIFAR-10/-100), la lettura da .tfrecord può aggiungere un sovraccarico significativo.

Poiché questi set di dati si adattano alla memoria, è possibile migliorare significativamente le prestazioni memorizzando nella cache o precaricando il set di dati. Tieni presente che TFDS memorizza automaticamente nella cache piccoli set di dati (la sezione seguente contiene i dettagli).

Memorizzazione nella cache del set di dati

Ecco un esempio di una pipeline di dati che memorizza esplicitamente nella cache il set di dati dopo aver normalizzato le immagini.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Quando si esegue l'iterazione su questo set di dati, la seconda iterazione sarà molto più veloce della prima grazie alla memorizzazione nella cache.

Memorizzazione nella cache automatica

Per impostazione predefinita, TFDS memorizza automaticamente nella cache (con ds.cache() ) i set di dati che soddisfano i seguenti vincoli:

  • La dimensione totale del set di dati (tutte le suddivisioni) è definita e < 250 MiB
  • shuffle_files è disabilitato o viene letto solo un singolo frammento

È possibile disattivare la memorizzazione nella cache automatica passando try_autocaching=False a tfds.ReadConfig in tfds.load . Dai un'occhiata alla documentazione del catalogo del set di dati per vedere se un set di dati specifico utilizzerà la cache automatica.

Caricamento dei dati completi come un singolo tensore

Se il set di dati rientra nella memoria, puoi anche caricare l'intero set di dati come un singolo array Tensor o NumPy. È possibile farlo impostando batch_size=-1 per raggruppare tutti gli esempi in un unico tf.Tensor . Quindi utilizzare tfds.as_numpy per la conversione da tf.Tensor a np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Set di dati di grandi dimensioni

I set di dati di grandi dimensioni vengono suddivisi in partizioni (divisi in più file) e in genere non entrano nella memoria, quindi non devono essere memorizzati nella cache.

Shuffle e allenamento

Durante l'allenamento, è importante mescolare bene i dati: i dati mescolati in modo inadeguato possono comportare una minore precisione dell'allenamento.

Oltre a utilizzare ds.shuffle per mescolare i record, dovresti anche impostare shuffle_files=True per ottenere un buon comportamento di mescolamento per set di dati più grandi suddivisi in più file. Altrimenti, le epoche leggeranno i frammenti nello stesso ordine e quindi i dati non saranno veramente randomizzati.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Inoltre, quando shuffle_files=True , TFDS disabilita options.deterministic , il che può fornire un leggero aumento delle prestazioni. Per ottenere un mescolamento deterministico, è possibile disattivare questa funzionalità con tfds.ReadConfig : impostando read_config.shuffle_seed o sovrascrivendo read_config.options.deterministic .

Suddividi automaticamente i tuoi dati tra i lavoratori (TF)

Durante l'addestramento su più lavoratori, puoi utilizzare l'argomento input_context di tfds.ReadConfig , in modo che ogni lavoratore leggerà un sottoinsieme di dati.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Questo è complementare all'API subsplit. Per prima cosa viene applicata l'API subplit: train[:50%] viene convertito in un elenco di file da leggere. Quindi, a tali file viene applicata un'operazione ds.shard() . Ad esempio, quando si utilizza train[:50%] con num_input_pipelines=2 , ciascuno dei 2 lavoratori leggerà 1/4 dei dati.

Quando shuffle_files=True , i file vengono mescolati all'interno di un lavoratore, ma non tra lavoratori. Ogni lavoratore leggerà lo stesso sottoinsieme di file tra le epoche.

Suddividi automaticamente i tuoi dati tra i lavoratori (Jax)

Con Jax, puoi utilizzare l'API tfds.split_for_jax_process o tfds.even_splits per distribuire i tuoi dati tra i lavoratori. Consulta la guida all'API divisa .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process è un semplice alias per:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Decodifica delle immagini più veloce

Per impostazione predefinita, TFDS decodifica automaticamente le immagini. Tuttavia, ci sono casi in cui può essere più efficace saltare la decodifica dell'immagine con tfds.decode.SkipDecoding e applicare manualmente l'operazione tf.io.decode_image :

Il codice per entrambi gli esempi è disponibile nella guida alla decodifica .

Salta le funzionalità non utilizzate

Se utilizzi solo un sottoinsieme delle funzionalità, è possibile ignorarne completamente alcune. Se il tuo set di dati ha molte funzionalità inutilizzate, non decodificarle può migliorare significativamente le prestazioni. Vedi https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features

tf.data utilizza tutta la mia RAM!

Se hai una RAM limitata o se stai caricando molti set di dati in parallelo mentre usi tf.data , ecco alcune opzioni che possono aiutarti:

Sostituisci la dimensione del buffer

builder.as_dataset(
  read_config=tfds.ReadConfig(
    ...
    override_buffer_size=1024,  # Save quite a bit of RAM.
  ),
  ...
)

Ciò sovrascrive buffer_size passato a TFRecordDataset (o equivalente): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args

Utilizza tf.data.Dataset.with_options per interrompere comportamenti magici

https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options

options = tf.data.Options()

# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
  tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False

data = data.with_options(options)