Pomoc chronić Wielkiej Rafy Koralowej z TensorFlow na Kaggle Dołącz Wyzwanie

Punkty kontrolne szkolenia

Zobacz na TensorFlow.org Wyświetl źródło na GitHub Pobierz notatnik

Wyrażenie „Zapisywanie modelu TensorFlow” zazwyczaj oznacza jedną z dwóch rzeczy:

  1. Punkty kontrolne, OR
  2. Zapisany Model.

Punkty kontrolne uchwycić dokładnej wartości wszystkich parametrów ( tf.Variable obiektów) używanych przez model. Punkty kontrolne nie zawierają żadnego opisu obliczeń zdefiniowanych przez model i dlatego są zwykle przydatne tylko wtedy, gdy dostępny jest kod źródłowy, który będzie używał zapisanych wartości parametrów.

Z drugiej strony format SavedModel zawiera zserializowany opis obliczeń zdefiniowanych przez model oprócz wartości parametrów (punkt kontrolny). Modele w tym formacie są niezależne od kodu źródłowego, który utworzył model. Dzięki temu nadają się do wdrożenia za pośrednictwem TensorFlow Serving, TensorFlow Lite, TensorFlow.js lub programów w innych językach programowania (C, C++, Java, Go, Rust, C# itp. API TensorFlow).

Ten przewodnik obejmuje interfejsy API do zapisywania i odczytywania punktów kontrolnych.

Ustawiać

import tensorflow as tf
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)
net = Net()

Zapisywanie z tf.keras API szkoleniowych

Patrz tf.keras kierować na zapisywanie i odtwarzanie.

tf.keras.Model.save_weights ratuje punkt kontrolny TensorFlow.

net.save_weights('easy_checkpoint')

Pisanie punktów kontrolnych

Uporczywy stan modelu TensorFlow jest przechowywany w tf.Variable obiektów. Mogą być wykonana bezpośrednio, ale często są tworzone poprzez API wysokiego poziomu, takich jak tf.keras.layers lub tf.keras.Model .

Najłatwiejszym sposobem zarządzania zmiennymi jest dołączanie ich do obiektów Pythona, a następnie odwoływanie się do tych obiektów.

Podklasy tf.train.Checkpoint , tf.keras.layers.Layer i tf.keras.Model automatycznie śledzić zmienne przypisane do ich atrybutów. Poniższy przykład konstruuje prosty model liniowy, a następnie zapisuje punkty kontrolne, które zawierają wartości dla wszystkich zmiennych modelu.

Możesz łatwo zapisać modela kontrolny z Model.save_weights .

Ręczne punkty kontrolne

Ustawiać

Aby pomóc wykazać wszystkie cechy tf.train.Checkpoint , określenie zestawu danych zabawki i etap optymalizacji:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

Utwórz obiekty punktów kontrolnych

Użyj tf.train.Checkpoint obiekt ręcznie utworzyć punkt kontrolny, gdzie obiekty do punktu kontrolnego są ustawione jako atrybuty obiektu.

tf.train.CheckpointManager mogą być również przydatne do zarządzania wieloma punktów kontrolnych.

opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

Trenuj i sprawdzaj model

W poniższej pętli szkolenia tworzy przypadek modelu i optymalizator, a następnie zbiera je w tf.train.Checkpoint obiektu. Wywołuje etap uczenia w pętli na każdej partii danych i okresowo zapisuje punkty kontrolne na dysku.

def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts/ckpt-1
loss 30.37
Saved checkpoint for step 20: ./tf_ckpts/ckpt-2
loss 23.79
Saved checkpoint for step 30: ./tf_ckpts/ckpt-3
loss 17.23
Saved checkpoint for step 40: ./tf_ckpts/ckpt-4
loss 10.72
Saved checkpoint for step 50: ./tf_ckpts/ckpt-5
loss 4.44

Przywróć i kontynuuj trening

Po pierwszym cyklu szkoleniowym możesz przekazać nowy model i menedżera, ale rozpocznij szkolenie dokładnie w miejscu, w którym je przerwałeś:

opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)

train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5
Saved checkpoint for step 60: ./tf_ckpts/ckpt-6
loss 1.33
Saved checkpoint for step 70: ./tf_ckpts/ckpt-7
loss 1.02
Saved checkpoint for step 80: ./tf_ckpts/ckpt-8
loss 0.68
Saved checkpoint for step 90: ./tf_ckpts/ckpt-9
loss 0.34
Saved checkpoint for step 100: ./tf_ckpts/ckpt-10
loss 0.16

tf.train.CheckpointManager obiekt usuwa stare punkty kontrolne. Powyżej jest skonfigurowany tak, aby zachować tylko trzy najnowsze punkty kontrolne.

print(manager.checkpoints)  # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']

Te ścieżki, np './tf_ckpts/ckpt-10' , nie są plikami na dysku. Zamiast tego są prefiksy dla index pliku i jeden lub więcej plików danych, które zawierają wartości zmiennych. Te prefiksy są zgrupowane w jednym checkpoint pliku ( './tf_ckpts/checkpoint' ), gdzie CheckpointManager oszczędza swój stan.

ls ./tf_ckpts
checkpoint           ckpt-8.data-00000-of-00001  ckpt-9.index
ckpt-10.data-00000-of-00001  ckpt-8.index
ckpt-10.index            ckpt-9.data-00000-of-00001

Mechanika ładowania

TensorFlow dopasowuje zmienne do wartości w punktach kontrolnych, przemierzając ukierunkowany graf z nazwanymi krawędziami, zaczynając od ładowanego obiektu. Nazwy krawędziowe zazwyczaj pochodzą z nazw atrybutów w obiektach, na przykład "l1" w self.l1 = tf.keras.layers.Dense(5) . tf.train.Checkpoint wykorzystuje swoje słowa kluczowego nazwy argumentów, tak jak w "step" w tf.train.Checkpoint(step=...) .

Wykres zależności z powyższego przykładu wygląda tak:

Wizualizacja wykresu zależności dla przykładowej pętli treningowej

Optymalizator jest w kolorze czerwonym, zwykłe zmienne w kolorze niebieskim, a zmienne w boksie optymalizatora w kolorze pomarańczowym. Innych węzłów, na przykład, reprezentujących tf.train.Checkpoint -są na czarno.

Zmienne przedziałów są częścią stanu optymalizatora, ale są tworzone dla określonej zmiennej. Na przykład 'm' krawędzie powyżej odpowiadają sile, którą Adam utworów Optimizer dla każdej zmiennej. Zmienne szczelin są zapisywane w punkcie kontrolnym tylko wtedy, gdy zmienna i optymalizator zostałyby zapisane, a więc krawędzie przerywane.

Wywołanie restore on a tf.train.Checkpoint obiektów kolejek wnioskowane odnowa przywrócenie wartości zmiennych tak szybko, jak tam jest ścieżka dopasowywania od Checkpoint obiektu. Na przykład można załadować tylko odchylenie z modelu zdefiniowanego powyżej, rekonstruując jedną ścieżkę do niego przez sieć i warstwę.

to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy())  # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())  # This gets the restored value.
[0. 0. 0. 0. 0.]
[2.9910686 3.8070676 3.252836  4.277522  3.8073184]

Wykres zależności dla tych nowych obiektów jest znacznie mniejszym podgrafem większego punktu kontrolnego, który napisałeś powyżej. Obejmuje ona tylko nastawienie i strzał licznika że tf.train.Checkpoint używa się liczba punktów kontrolnych.

Wizualizacja podwykresu dla zmiennej obciążenia

restore Zwraca obiekt stanu, który ma opcjonalne twierdzeń. Wszystkich obiektów utworzonych w nowej Checkpoint zostały przywrócone, więc status.assert_existing_objects_matched przepustki.

status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f898454bed0>

W punkcie kontrolnym znajduje się wiele obiektów, które nie pasują, w tym jądro warstwy i zmienne optymalizatora. status.assert_consumed przechodzi tylko wtedy, gdy punkt kontrolny i mecz Program dokładnie i będzie wyjątek tutaj.

Opóźnione uzupełnienia

Layer obiektów w TensorFlow może opóźnić tworzenie zmiennych do ich pierwszej rozmowy, gdy kształty wejściowe są dostępne. Na przykład kształt Dense jądrze warstwy zależy zarówno kształty wejściowych i wyjściowych warstwy, a więc kształtu wyjściowego wymagane jako argument konstruktora nie ma wystarczających informacji, aby utworzyć zmienną na własną rękę. Ponieważ wywołanie Layer również odczytuje wartość zmiennej, przywracania musi się zdarzyć między stworzeniem tej zmiennej i jej pierwszego użycia.

W celu wsparcia tego idiomu, tf.train.Checkpoint kolejek odbudowała, które nie mają jeszcze zmienną dopasowanie.

delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())  # Not restored; still zeros
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())  # Restored
[[0. 0. 0. 0. 0.]]
[[4.5344105 4.5929823 4.7816424 4.758177  5.007635 ]]

Ręczne sprawdzanie punktów kontrolnych

tf.train.load_checkpoint zwraca CheckpointReader który daje niższy poziom dostępu do treści punktu kontrolnego. Zawiera mapowania z klucza każdej zmiennej do kształtu i typu d dla każdej zmiennej w punkcie kontrolnym. Kluczem do zmiennej jest ścieżka jej obiektu, tak jak na powyższych wykresach.

reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()

sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH',
 'iterator/.ATTRIBUTES/ITERATOR_STATE',
 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE',
 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE',
 'save_counter/.ATTRIBUTES/VARIABLE_VALUE',
 'step/.ATTRIBUTES/VARIABLE_VALUE']

Więc jeśli jesteś zainteresowany w wartości net.l1.kernel można uzyskać wartość z następującego kodu:

key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'

print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5]
Dtype: float32

Zapewnia również get_tensor metodę pozwalającą na sprawdzenie wartości zmiennej:

reader.get_tensor(key)
array([[4.5344105, 4.5929823, 4.7816424, 4.758177 , 5.007635 ]],
      dtype=float32)

Śledzenie obiektów

Punkty kontrolne zapisywanie i przywracanie wartości tf.Variable obiektów przez „śledzenia” każdej zmiennej lub trackable zestawu obiektów w jednym z jego atrybutów. Podczas wykonywania zapisu zmienne są zbierane rekursywnie ze wszystkich osiągalnych śledzonych obiektów.

Podobnie jak w przypadku bezpośredniego przypisania atrybutów takich jak self.l1 = tf.keras.layers.Dense(5) , przypisywanie list i słowniki atrybutów będzie śledzić ich zawartość.

save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy()  # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

Możesz zauważyć opakowujące obiekty dla list i słowników. Opakowania te są możliwymi do sprawdzenia wersjami podstawowych struktur danych. Podobnie jak ładowanie oparte na atrybutach, te opakowania przywracają wartość zmiennej zaraz po jej dodaniu do kontenera.

restore.listed = []
print(restore.listed)  # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)  # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])

Możliwe do śledzenia obiektów obejmuje tf.train.Checkpoint , tf.Module i podklas (np keras.layers.Layer i keras.Model ) oraz uznane pojemników Pythonie

  • dict (i collections.OrderedDict )
  • list
  • tuple (i collections.namedtuple , typing.NamedTuple )

Inne rodzaje kontenerów nieobsługiwane, w tym:

  • collections.defaultdict
  • set

Wszystkie inne obiekty Pythona są ignorowane, w tym:

  • int
  • string
  • float

Streszczenie

Obiekty TensorFlow zapewniają łatwy automatyczny mechanizm zapisywania i przywracania wartości wykorzystywanych przez nich zmiennych.