TFDS e determinismo

Ver no TensorFlow.org Executar no Google Colab Ver no GitHub Baixar caderno

Este documento explica:

  • As garantias TFDS sobre determinismo
  • Em que ordem o TFDS lê os exemplos
  • Várias advertências e pegadinhas

Configurar

Conjuntos de dados

É necessário algum contexto para entender como o TFDS lê os dados.

Durante a geração, TFDS escrever os dados originais em padronizados .tfrecord arquivos. Para grandes conjuntos de dados, vários .tfrecord arquivos são criados, cada um contendo vários exemplos. Chamamos cada .tfrecord apresentar um caco.

Este guia usa imagenet que tem 1024 fragmentos:

import re
import tensorflow_datasets as tfds

imagenet = tfds.builder('imagenet2012')

num_shards = imagenet.info.splits['train'].num_shards
num_examples = imagenet.info.splits['train'].num_examples
print(f'imagenet has {num_shards} shards ({num_examples} examples)')
imagenet has 1024 shards (1281167 examples)

Encontrar os IDs de exemplos do conjunto de dados

Você pode pular para a seção seguinte se quiser apenas saber sobre determinismo.

Cada exemplo de conjunto de dados é unicamente identificado por um id (por exemplo, 'imagenet2012-train.tfrecord-01023-of-01024__32' ). Você pode recuperar este id passando read_config.add_tfds_id = True que irá adicionar um 'tfds_id' chave na dict do tf.data.Dataset .

Neste tutorial, definimos um pequeno utilitário que imprimirá os ids de exemplo do conjunto de dados (convertidos em inteiros para serem mais legíveis):

Determinismo ao ler

Esta seção explica garantia deterministim de tfds.load .

Com shuffle_files=False (padrão)

Por TFDS padrão originar os exemplos deterministamente ( shuffle_files=False )

# Same as: imagenet.as_dataset(split='train').take(20)
print_ex_ids(imagenet, split='train', take=20)
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]

Para o desempenho, TFDS ler vários fragmentos, ao mesmo tempo usando tf.data.Dataset.interleave . Vemos neste exemplo que TFDS mudar para caco 2 depois de ler 16 exemplos ( ..., 14, 15, 1251, 1252, ... ). Mais sobre intercalar abaixo.

Da mesma forma, a API de subsplit também é determinística:

print_ex_ids(imagenet, split='train[67%:84%]', take=20)
print_ex_ids(imagenet, split='train[67%:84%]', take=20)
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]
[858382, 858383, 858384, 858385, 858386, 858387, 858388, 858389, 858390, 858391, 858392, 858393, 858394, 858395, 858396, 858397, 859533, 859534, 859535, 859536]

Se você está treinando para mais de uma época, a configuração acima não é recomendada como todas as épocas lerá os cacos na mesma ordem (para aleatoriedade é limitado aos ds = ds.shuffle(buffer) tamanho do buffer).

Com shuffle_files=True

Com shuffle_files=True , cacos são embaralhadas para cada época, de modo que a leitura não é mais determinista.

print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
print_ex_ids(imagenet, split='train', shuffle_files=True, take=20)
[568017, 329050, 329051, 329052, 329053, 329054, 329056, 329055, 568019, 568020, 568021, 568022, 568023, 568018, 568025, 568024, 568026, 568028, 568030, 568031]
[43790, 43791, 43792, 43793, 43796, 43794, 43797, 43798, 43795, 43799, 43800, 43801, 43802, 43803, 43804, 43805, 43806, 43807, 43809, 43810]

Veja a receita abaixo para obter uma ordem aleatória de arquivos determinística.

Advertência sobre determinismo: intercalar argumentos

Alterar read_config.interleave_cycle_length , read_config.interleave_block_length irá alterar a ordem exemplos.

TFDS depende tf.data.Dataset.interleave para carregar apenas alguns fragmentos de uma vez, melhorar o desempenho e reduzir o uso de memória.

O pedido de exemplo só tem garantia de ser o mesmo para um valor fixo de argumentos de intercalação. Veja doc intercalam para entender o que cycle_length e block_length correspondem também.

  • cycle_length=16 , block_length=16 (por defeito, o mesmo que acima):
print_ex_ids(imagenet, split='train', take=20)
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254]
  • cycle_length=3 , block_length=2 :
read_config = tfds.ReadConfig(
    interleave_cycle_length=3,
    interleave_block_length=2,
)
print_ex_ids(imagenet, split='train', read_config=read_config, take=20)
[0, 1, 1251, 1252, 2502, 2503, 2, 3, 1253, 1254, 2504, 2505, 4, 5, 1255, 1256, 2506, 2507, 6, 7]

No segundo exemplo, vemos que o conjunto de dados de leitura 2 ( block_length=2 ) Exemplos em um fragmento, em seguida mudar para o próximo fragmento. Cada 2 * 3 ( cycle_length=3 ) exemplos, que passa de volta para o primeiro fragmento ( shard0-ex0, shard0-ex1, shard1-ex0, shard1-ex1, shard2-ex0, shard2-ex1, shard0-ex2, shard0-ex3, shard1-ex2, shard1-ex3, shard2-ex2,... ).

Subsplit e exemplo de pedido

Cada exemplo tem um ID de 0, 1, ..., num_examples-1 . A API subsplit selecionar uma fatia de exemplos (por exemplo train[:x] selecione 0, 1, ..., x-1 ).

No entanto, dentro da sub-divisão, os exemplos não são lidos em ordem crescente de id (devido a fragmentos e intercalação).

Mais especificamente, ds.take(x) e split='train[:x]' não são equivalentes!

Isso pode ser visto facilmente no exemplo de intercalação acima, onde exemplos vêm de diferentes fragmentos.

print_ex_ids(imagenet, split='train', take=25)  # tfds.load(..., split='train').take(25)
print_ex_ids(imagenet, split='train[:25]', take=-1)  # tfds.load(..., split='train[:25]')
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1251, 1252, 1253, 1254, 1255, 1256, 1257, 1258, 1259]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]

Após os exemplos 16 (block_length), .take(25) muda para o próximo fragmento enquanto train[:25] continuar a ler na exemplos do primeiro fragmento.

Receitas

Obter embaralhamento determinístico de arquivos

Existem 2 maneiras de ter embaralhamento determinístico:

  1. Definir o shuffle_seed . Nota: Isso requer a mudança da semente em cada época, caso contrário, os fragmentos serão lidos na mesma ordem entre as épocas.
read_config = tfds.ReadConfig(
    shuffle_seed=32,
)

# Deterministic order, different from the default shuffle_files=False above
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
print_ex_ids(imagenet, split='train', shuffle_files=True, read_config=read_config, take=22)
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
[176411, 176412, 176413, 176414, 176415, 176416, 176417, 176418, 176419, 176420, 176421, 176422, 176423, 176424, 176425, 176426, 710647, 710648, 710649, 710650, 710651, 710652]
  1. Usando experimental_interleave_sort_fn : Isto dá total controle sobre quais cacos são lidos e em que ordem, em vez de depender ds.shuffle ordem.
def _reverse_order(file_instructions):
  return list(reversed(file_instructions))

read_config = tfds.ReadConfig(
    experimental_interleave_sort_fn=_reverse_order,
)

# Last shard (01023-of-01024) is read first
print_ex_ids(imagenet, split='train', read_config=read_config, take=5)
[1279916, 1279917, 1279918, 1279919, 1279920]

Obtenha pipeline preemptivo determinístico

Este é mais complicado. Não existe uma solução fácil e satisfatória.

  1. Sem ds.shuffle e com baralhar determinista, em teoria, deveria ser possível contar os exemplos que tenham sido lidos e deduzir que exemplos foram lidas dentro em cada fragmento (como uma função de cycle_length , block_length e ordem caco). Em seguida, o skip , take por cada fragmento podia ser injectado através de experimental_interleave_sort_fn .

  2. Com ds.shuffle é provável impossível sem reproduzir o gasoduto de treinamento completo. Seria necessário salvar o ds.shuffle estado-tampão para deduzir que exemplos foram lidas. Os exemplos podem ser não-contínuos (por exemplo, shard5_ex2 , shard5_ex4 ler mas não shard5_ex3 ).

  3. Com ds.shuffle , uma maneira seria para salvar todos os shards_ids / example_ids Read (deduzidas tfds_id ), em seguida, deduzindo as instruções do arquivo a partir daí.

O caso mais simples para 1. é ter .skip(x).take(y) jogo train[x:x+y] jogo. Isso requer:

  • Conjunto cycle_length=1 (assim fragmentos são lidos sequencialmente)
  • Definir shuffle_files=False
  • Não use ds.shuffle

Ele só deve ser usado em um grande conjunto de dados em que o treinamento dura apenas 1 época. Os exemplos seriam lidos na ordem aleatória padrão.

read_config = tfds.ReadConfig(
    interleave_cycle_length=1,  # Read shards sequentially
)

print_ex_ids(imagenet, split='train', read_config=read_config, skip=40, take=22)
# If the job get pre-empted, using the subsplit API will skip at most `len(shard0)`
print_ex_ids(imagenet, split='train[40:]', read_config=read_config, take=22)
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61]

Descubra quais fragmentos / exemplos são lidos para uma determinada sub-divisão

Com a tfds.core.DatasetInfo , você tem acesso direto às instruções de leitura.

imagenet.info.splits['train[44%:45%]'].file_instructions
[FileInstruction(filename='imagenet2012-train.tfrecord-00450-of-01024', skip=700, take=-1, num_examples=551),
 FileInstruction(filename='imagenet2012-train.tfrecord-00451-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00452-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00453-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00454-of-01024', skip=0, take=-1, num_examples=1252),
 FileInstruction(filename='imagenet2012-train.tfrecord-00455-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00456-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00457-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00458-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00459-of-01024', skip=0, take=-1, num_examples=1251),
 FileInstruction(filename='imagenet2012-train.tfrecord-00460-of-01024', skip=0, take=1001, num_examples=1001)]