Este documento fornece dicas de desempenho específicas do TensorFlow Datasets (TFDS). Observe que o TFDS fornece conjuntos de dados como objetos tf.data.Dataset
, portanto, o conselho do guia tf.data
ainda se aplica.
Conjuntos de dados de referência
Use tfds.benchmark(ds)
para avaliar qualquer objeto tf.data.Dataset
.
Certifique-se de indicar batch_size=
para normalizar os resultados (por exemplo, 100 iter/seg -> 3200 ex/seg). Isso funciona com qualquer iterável (por exemplo tfds.benchmark(tfds.as_numpy(ds))
).
ds = tfds.load('mnist', split='train').batch(32).prefetch()
# Display some benchmark statistics
tfds.benchmark(ds, batch_size=32)
# Second iteration is much faster, due to auto-caching
tfds.benchmark(ds, batch_size=32)
Conjuntos de dados pequenos (menos de 1 GB)
Todos os conjuntos de dados TFDS armazenam os dados em disco no formato TFRecord
. Para conjuntos de dados pequenos (por exemplo, MNIST, CIFAR-10/-100), a leitura de .tfrecord
pode adicionar sobrecarga significativa.
À medida que esses conjuntos de dados cabem na memória, é possível melhorar significativamente o desempenho armazenando em cache ou pré-carregando o conjunto de dados. Observe que o TFDS armazena automaticamente em cache pequenos conjuntos de dados (a seção a seguir contém os detalhes).
Armazenando o conjunto de dados em cache
Aqui está um exemplo de pipeline de dados que armazena explicitamente o conjunto de dados em cache após normalizar as imagens.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds, ds_info = tfds.load(
'mnist',
split='train',
as_supervised=True, # returns `(img, label)` instead of dict(image=, ...)
with_info=True,
)
# Applying normalization before `ds.cache()` to re-use it.
# Note: Random transformations (e.g. images augmentations) should be applied
# after both `ds.cache()` (to avoid caching randomness) and `ds.batch()` (for
# vectorization [1]).
ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.cache()
# For true randomness, we set the shuffle buffer to the full dataset size.
ds = ds.shuffle(ds_info.splits['train'].num_examples)
# Batch after shuffling to get unique batches at each epoch.
ds = ds.batch(128)
ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
Ao iterar neste conjunto de dados, a segunda iteração será muito mais rápida que a primeira graças ao cache.
Cache automático
Por padrão, o TFDS armazena em cache automático (com ds.cache()
) conjuntos de dados que satisfazem as seguintes restrições:
- O tamanho total do conjunto de dados (todas as divisões) é definido e <250 MiB
-
shuffle_files
está desativado ou apenas um único fragmento é lido
É possível cancelar o cache automático passando try_autocaching=False
para tfds.ReadConfig
em tfds.load
. Dê uma olhada na documentação do catálogo do conjunto de dados para ver se um conjunto de dados específico usará o cache automático.
Carregando os dados completos como um único Tensor
Se o seu conjunto de dados couber na memória, você também pode carregar o conjunto de dados completo como um único array Tensor ou NumPy. É possível fazer isso definindo batch_size=-1
para agrupar todos os exemplos em um único tf.Tensor
. Em seguida, use tfds.as_numpy
para a conversão de tf.Tensor
para np.array
.
(img_train, label_train), (img_test, label_test) = tfds.as_numpy(tfds.load(
'mnist',
split=['train', 'test'],
batch_size=-1,
as_supervised=True,
))
Grandes conjuntos de dados
Grandes conjuntos de dados são fragmentados (divididos em vários arquivos) e normalmente não cabem na memória, portanto, não devem ser armazenados em cache.
Embaralhamento e treinamento
Durante o treinamento, é importante embaralhar bem os dados – dados mal embaralhados podem resultar em menor precisão do treinamento.
Além de usar ds.shuffle
para embaralhar registros, você também deve definir shuffle_files=True
para obter um bom comportamento de embaralhamento para conjuntos de dados maiores que são fragmentados em vários arquivos. Caso contrário, as épocas lerão os fragmentos na mesma ordem e, portanto, os dados não serão verdadeiramente randomizados.
ds = tfds.load('imagenet2012', split='train', shuffle_files=True)
Além disso, quando shuffle_files=True
, o TFDS desativa options.deterministic
, o que pode proporcionar um ligeiro aumento de desempenho. Para obter embaralhamento determinístico, é possível desativar esse recurso com tfds.ReadConfig
: configurando read_config.shuffle_seed
ou substituindo read_config.options.deterministic
.
Fragmente automaticamente seus dados entre trabalhadores (TF)
Ao treinar vários trabalhadores, você pode usar o argumento input_context
de tfds.ReadConfig
, para que cada trabalhador leia um subconjunto dos dados.
input_context = tf.distribute.InputContext(
input_pipeline_id=1, # Worker id
num_input_pipelines=4, # Total number of workers
)
read_config = tfds.ReadConfig(
input_context=input_context,
)
ds = tfds.load('dataset', split='train', read_config=read_config)
Isso é complementar à API subsplit. Primeiro, a API subplit é aplicada: train[:50%]
é convertido em uma lista de arquivos para leitura. Em seguida, uma operação ds.shard()
é aplicada a esses arquivos. Por exemplo, ao usar train[:50%]
com num_input_pipelines=2
, cada um dos 2 trabalhadores lerá 1/4 dos dados.
Quando shuffle_files=True
, os arquivos são embaralhados dentro de um trabalhador, mas não entre trabalhadores. Cada trabalhador lerá o mesmo subconjunto de arquivos entre épocas.
Fragmente automaticamente seus dados entre trabalhadores (Jax)
Com Jax, você pode usar a API tfds.split_for_jax_process
ou tfds.even_splits
para distribuir seus dados entre trabalhadores. Consulte o guia da API dividida .
split = tfds.split_for_jax_process('train', drop_remainder=True)
ds = tfds.load('my_dataset', split=split)
tfds.split_for_jax_process
é um alias simples para:
# The current `process_index` loads only `1 / process_count` of the data.
splits = tfds.even_splits('train', n=jax.process_count(), drop_remainder=True)
split = splits[jax.process_index()]
Decodificação de imagem mais rápida
Por padrão, o TFDS decodifica imagens automaticamente. No entanto, há casos em que pode ser mais eficiente pular a decodificação da imagem com tfds.decode.SkipDecoding
e aplicar manualmente a operação tf.io.decode_image
:
- Ao filtrar exemplos (com
tf.data.Dataset.filter
), para decodificar imagens após a filtragem dos exemplos. - Ao cortar imagens, use a operação fundida
tf.image.decode_and_crop_jpeg
.
O código para ambos os exemplos está disponível no guia de decodificação .
Ignorar recursos não utilizados
Se você estiver usando apenas um subconjunto de recursos, é possível ignorar completamente alguns recursos. Se o seu conjunto de dados tiver muitos recursos não utilizados, não decodificá-los pode melhorar significativamente o desempenho. Consulte https://www.tensorflow.org/datasets/decode#only_decode_a_sub-set_of_the_features
tf.data usa toda a minha RAM!
Se você tiver memória RAM limitada ou estiver carregando muitos conjuntos de dados em paralelo ao usar tf.data
, aqui estão algumas opções que podem ajudar:
Substituir tamanho do buffer
builder.as_dataset(
read_config=tfds.ReadConfig(
...
override_buffer_size=1024, # Save quite a bit of RAM.
),
...
)
Isso substitui o buffer_size
passado para TFRecordDataset
(ou equivalente): https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset#args
Use tf.data.Dataset.with_options para interromper comportamentos mágicos
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#with_options
options = tf.data.Options()
# Stop magic stuff that eats up RAM:
options.autotune.enabled = False
options.experimental_distribute.auto_shard_policy = (
tf.data.experimental.AutoShardPolicy.OFF)
options.experimental_optimization.inject_prefetch = False
data = data.with_options(options)