Suggerimenti per le prestazioni

Questo documento fornisce suggerimenti sulle prestazioni specifici di TensorFlow Datasets (TFDS). Si noti che TFDS fornisce set di dati come oggetti tf.data.Dataset , quindi si applicano ancora i consigli della guida tf.data .

Set di dati di riferimento

Usa tfds.benchmark(ds) per confrontare qualsiasi oggetto tf.data.Dataset .

Assicurati di indicare batch_size= per normalizzare i risultati (es. 100 iter/sec -> 3200 ex/sec). Funziona con qualsiasi iterabile (ad esempio tfds.benchmark(tfds.as_numpy(ds)) ).

ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)

Piccoli set di dati (meno di 1 GB)

Tutti i set di dati TFDS memorizzano i dati su disco nel formato TFRecord . Per piccoli set di dati (ad es. MNIST, CIFAR-10/-100), la lettura da .tfrecord può aggiungere un sovraccarico significativo.

Poiché tali set di dati si adattano alla memoria, è possibile migliorare significativamente le prestazioni memorizzando nella cache o precaricando il set di dati. Si noti che TFDS memorizza automaticamente nella cache piccoli set di dati (la sezione seguente contiene i dettagli).

Memorizzazione nella cache del set di dati

Ecco un esempio di una pipeline di dati che memorizza esplicitamente nella cache il set di dati dopo aver normalizzato le immagini.

def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.cast(image, tf.float32) / 255., label


ds, ds_info = tfds.load(
    'mnist',
    split='train',
    as_supervised=True,  # returns `(img, label)` instead of dict(image=, ...)
    with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)

Quando si esegue l'iterazione su questo set di dati, la seconda iterazione sarà molto più veloce della prima grazie alla memorizzazione nella cache.

Memorizzazione automatica nella cache

Per impostazione predefinita, i set di dati TFDS memorizzano automaticamente nella cache (con ds.cache() ) i seguenti vincoli:

  • La dimensione totale del set di dati (tutte le suddivisioni) è definita e < 250 MiB
  • shuffle_files è disabilitato o viene letto solo un singolo shard

È possibile disattivare la memorizzazione nella cache automatica passando try_autocaching=False a tfds.ReadConfig in tfds.load . Dai un'occhiata alla documentazione del catalogo del set di dati per vedere se un set di dati specifico utilizzerà la cache automatica.

Caricamento dei dati completi come un unico tensore

Se il tuo set di dati si adatta alla memoria, puoi anche caricare l'intero set di dati come un singolo array Tensor o NumPy. È possibile farlo impostando batch_size=-1 per raggruppare tutti gli esempi in un singolo tf.Tensor . Quindi usa tfds.as_numpy per la conversione da tf.Tensor a np.array .

(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
    'mnist',
    split=['train', 'test'],
    batch_size=-1,
    as_supervised=True,
))

Grandi set di dati

I set di dati di grandi dimensioni vengono partizionati (divisi in più file) e in genere non si adattano alla memoria, quindi non devono essere memorizzati nella cache.

Mescolare e allenare

Durante l'allenamento, è importante mescolare bene i dati: dati mescolati male possono comportare una minore precisione dell'allenamento.

Oltre a utilizzare ds.shuffle per mescolare i record, dovresti anche impostare shuffle_files=True per ottenere un buon comportamento di mescolamento per set di dati più grandi che sono partizionati in più file. In caso contrario, epochs leggerà i frammenti nello stesso ordine e quindi i dati non saranno veramente randomizzati.

ds = tfds.load('imagenet2012', split='train', shuffle_files=True)

Inoltre, quando shuffle_files=True , TFDS disabilita options.deterministic , che potrebbe dare un leggero aumento delle prestazioni. Per ottenere un miscuglio deterministico, è possibile disattivare questa funzione con tfds.ReadConfig : impostando read_config.shuffle_seed o sovrascrivendo read_config.options.deterministic .

Condivisione automatica dei dati tra lavoratori (TF)

Quando si esegue l'addestramento su più lavoratori, è possibile utilizzare l'argomento input_context di tfds.ReadConfig , in modo che ogni lavoratore leggerà un sottoinsieme di dati.

input_context = tf.distribute.InputContext(
    input_pipeline_id=1,  # Worker id
    num_input_pipelines=4,  # Total number of workers
)
read_config = tfds.ReadConfig(
    input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)

Questo è complementare all'API subsplit. Innanzitutto, viene applicata l'API subplit: train[:50%] viene convertito in un elenco di file da leggere. Quindi, un'operazione ds.shard() viene applicata a quei file. Ad esempio, quando si utilizza train[:50%] con num_input_pipelines=2 , ciascuno dei 2 lavoratori leggerà 1/4 dei dati.

Quando shuffle_files=True , i file vengono mischiati all'interno di un lavoratore, ma non tra i lavoratori. Ogni lavoratore leggerà lo stesso sottoinsieme di file tra epoche.

Sharding automatico dei dati tra i dipendenti (Jax)

Con Jax, puoi utilizzare l'API tfds.split_for_jax_process o tfds.even_splits per distribuire i tuoi dati tra i lavoratori. Consulta la guida dell'API divisa .

split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)

tfds.split_for_jax_process è un semplice alias per:

# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]

Decodifica delle immagini più veloce

Per impostazione predefinita, TFDS decodifica automaticamente le immagini. Tuttavia, ci sono casi in cui può essere più efficace saltare la decodifica dell'immagine con tfds.decode.SkipDecoding e applicare manualmente l'op tf.io.decode_image :

Il codice per entrambi gli esempi è disponibile nella guida alla decodifica .

Salta le funzioni non utilizzate

Se stai utilizzando solo un sottoinsieme delle funzionalità, è possibile ignorare completamente alcune funzionalità. Se il tuo set di dati ha molte funzionalità inutilizzate, non decodificarle può migliorare significativamente le prestazioni. Vedi https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features