Посмотреть на TensorFlow.org | Запустить в Google Colab | Посмотреть исходный код на GitHub | Скачать блокнот |
Фраза «Сохранение модели TensorFlow» обычно означает одно из двух:
- Контрольно-пропускные пункты, ИЛИ
- Сохраненная модель.
Контрольные точки фиксируют точное значение всех параметров (объектов tf.Variable
), используемых моделью. Контрольные точки не содержат никакого описания вычислений, определенных моделью, и поэтому обычно полезны только тогда, когда доступен исходный код, который будет использовать сохраненные значения параметров.
С другой стороны, формат SavedModel включает сериализованное описание вычислений, определенных моделью, в дополнение к значениям параметров (контрольная точка). Модели в этом формате не зависят от исходного кода, создавшего модель. Таким образом, они подходят для развертывания через TensorFlow Serving, TensorFlow Lite, TensorFlow.js или программы на других языках программирования (C, C++, Java, Go, Rust, C# и т. д. API-интерфейсы TensorFlow).
В этом руководстве рассматриваются API для записи и чтения контрольных точек.
Настраивать
import tensorflow as tf
class Net(tf.keras.Model):
"""A simple linear model."""
def __init__(self):
super(Net, self).__init__()
self.l1 = tf.keras.layers.Dense(5)
def call(self, x):
return self.l1(x)
net = Net()
Сохранение из обучающих API tf.keras
См. руководство tf.keras
по сохранению и восстановлению файлов .
tf.keras.Model.save_weights
сохраняет контрольную точку TensorFlow.
net.save_weights('easy_checkpoint')
Написание контрольных точек
Постоянное состояние модели TensorFlow хранится в объектах tf.Variable
. Они могут быть созданы напрямую, но часто создаются с помощью высокоуровневых API, таких как tf.keras.layers
или tf.keras.Model
.
Самый простой способ управлять переменными — присоединить их к объектам Python, а затем ссылаться на эти объекты.
Подклассы tf.train.Checkpoint
, tf.keras.layers.Layer
и tf.keras.Model
автоматически отслеживают переменные, назначенные их атрибутам. В следующем примере создается простая линейная модель, а затем записываются контрольные точки, содержащие значения для всех переменных модели.
Вы можете легко сохранить контрольную точку модели с помощью Model.save_weights
.
Ручная проверка
Настраивать
Чтобы продемонстрировать все функции tf.train.Checkpoint
, определите игрушечный набор данных и шаг оптимизации:
def toy_dataset():
inputs = tf.range(10.)[:, None]
labels = inputs * 5. + tf.range(5.)[None, :]
return tf.data.Dataset.from_tensor_slices(
dict(x=inputs, y=labels)).repeat().batch(2)
def train_step(net, example, optimizer):
"""Trains `net` on `example` using `optimizer`."""
with tf.GradientTape() as tape:
output = net(example['x'])
loss = tf.reduce_mean(tf.abs(output - example['y']))
variables = net.trainable_variables
gradients = tape.gradient(loss, variables)
optimizer.apply_gradients(zip(gradients, variables))
return loss
Создание объектов контрольной точки
Используйте объект tf.train.Checkpoint
, чтобы вручную создать контрольную точку, где объекты, которые вы хотите проверить, устанавливаются как атрибуты объекта.
tf.train.CheckpointManager
также может быть полезен для управления несколькими контрольными точками.
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
Обучите и проверьте модель
Следующий обучающий цикл создает экземпляр модели и оптимизатора, а затем собирает их в объект tf.train.Checkpoint
. Он вызывает этап обучения в цикле для каждого пакета данных и периодически записывает контрольные точки на диск.
def train_and_checkpoint(net, manager):
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
print("Restored from {}".format(manager.latest_checkpoint))
else:
print("Initializing from scratch.")
for _ in range(50):
example = next(iterator)
loss = train_step(net, example, opt)
ckpt.step.assign_add(1)
if int(ckpt.step) % 10 == 0:
save_path = manager.save()
print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
print("loss {:1.2f}".format(loss.numpy()))
train_and_checkpoint(net, manager)
Initializing from scratch. Saved checkpoint for step 10: ./tf_ckpts/ckpt-1 loss 31.27 Saved checkpoint for step 20: ./tf_ckpts/ckpt-2 loss 24.68 Saved checkpoint for step 30: ./tf_ckpts/ckpt-3 loss 18.12 Saved checkpoint for step 40: ./tf_ckpts/ckpt-4 loss 11.65 Saved checkpoint for step 50: ./tf_ckpts/ckpt-5 loss 5.39
Восстановить и продолжить обучение
После первого цикла обучения вы можете пройти новую модель и менеджера, но продолжить обучение именно там, где остановились:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
train_and_checkpoint(net, manager)
Restored from ./tf_ckpts/ckpt-5 Saved checkpoint for step 60: ./tf_ckpts/ckpt-6 loss 1.50 Saved checkpoint for step 70: ./tf_ckpts/ckpt-7 loss 1.27 Saved checkpoint for step 80: ./tf_ckpts/ckpt-8 loss 0.56 Saved checkpoint for step 90: ./tf_ckpts/ckpt-9 loss 0.70 Saved checkpoint for step 100: ./tf_ckpts/ckpt-10 loss 0.35
Объект tf.train.CheckpointManager
удаляет старые контрольные точки. Выше он настроен на сохранение только трех последних контрольных точек.
print(manager.checkpoints) # List the three remaining checkpoints
['./tf_ckpts/ckpt-8', './tf_ckpts/ckpt-9', './tf_ckpts/ckpt-10']
Эти пути, например './tf_ckpts/ckpt-10'
, не являются файлами на диске. Вместо этого они являются префиксами для index
файла и одного или нескольких файлов данных, содержащих значения переменных. Эти префиксы сгруппированы вместе в одном файле checkpoint
( './tf_ckpts/checkpoint'
), где CheckpointManager
сохраняет свое состояние.
ls ./tf_ckpts
checkpoint ckpt-8.data-00000-of-00001 ckpt-9.index ckpt-10.data-00000-of-00001 ckpt-8.index ckpt-10.index ckpt-9.data-00000-of-00001
Механика загрузки
TensorFlow сопоставляет переменные со значениями контрольных точек, проходя по ориентированному графу с именованными ребрами, начиная с загружаемого объекта. Имена ребер обычно берутся из имен атрибутов в объектах, например, "l1"
в self.l1 = tf.keras.layers.Dense(5)
. tf.train.Checkpoint
использует свои имена аргументов ключевого слова, как в "step"
в tf.train.Checkpoint(step=...)
.
Граф зависимостей из примера выше выглядит так:
Оптимизатор выделен красным, обычные переменные — синим, а переменные слота оптимизатора — оранжевым. Другие узлы, например, представляющие tf.train.Checkpoint
, выделены черным цветом.
Переменные слота являются частью состояния оптимизатора, но создаются для конкретной переменной. Например, ребра 'm'
выше соответствуют импульсу, который оптимизатор Адама отслеживает для каждой переменной. Переменные слота сохраняются в контрольной точке только в том случае, если и переменная, и оптимизатор должны быть сохранены, поэтому ребра заштрихованы.
Вызов restore
объекта tf.train.Checkpoint
в очередь запрошенные восстановления, восстанавливая значения переменных, как только появляется соответствующий путь от объекта Checkpoint
. Например, вы можете загрузить только смещение из модели, которую вы определили выше, реконструируя один путь к ней через сеть и слой.
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # All zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy()) # This gets the restored value.
[0. 0. 0. 0. 0.] [2.7209885 3.7588918 4.421351 4.1466427 4.0712557]
Граф зависимостей для этих новых объектов является гораздо меньшим подграфом более крупной контрольной точки, которую вы написали выше. Он включает только смещение и счетчик сохранений, которые tf.train.Checkpoint
использует для нумерации контрольных точек.
restore
возвращает объект состояния, который имеет необязательные утверждения. Все объекты, созданные в новой Checkpoint
точке, были восстановлены, поэтому status.assert_existing_objects_matched
проходит.
status.assert_existing_objects_matched()
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f93a075b9d0>
В контрольной точке много объектов, которые не совпали, включая ядро слоя и переменные оптимизатора. status.assert_consumed
проходит только в том случае, если контрольная точка и программа точно совпадают, и здесь будет выдано исключение.
Отсроченные реставрации
Объекты Layer
в TensorFlow могут откладывать создание переменных до их первого вызова, когда доступны входные формы. Например, форма ядра Dense
слоя зависит как от входных, так и от выходных форм слоя, поэтому выходной формы, требуемой в качестве аргумента конструктора, недостаточно для создания переменной самой по себе. Поскольку при вызове Layer
также считывается значение переменной, между созданием переменной и ее первым использованием должно произойти восстановление.
Для поддержки этой идиомы tf.train.Checkpoint
откладывает восстановление, для которого еще нет соответствующей переменной.
deferred_restore = tf.Variable(tf.zeros([1, 5]))
print(deferred_restore.numpy()) # Not restored; still zeros
fake_layer.kernel = deferred_restore
print(deferred_restore.numpy()) # Restored
[[0. 0. 0. 0. 0.]] [[4.5854754 4.607731 4.649179 4.8474874 5.121 ]]
Проверка контрольно-пропускных пунктов вручную
tf.train.load_checkpoint
возвращает CheckpointReader
, который предоставляет доступ нижнего уровня к содержимому контрольной точки. Он содержит сопоставления ключа каждой переменной с формой и типом каждой переменной в контрольной точке. Ключ переменной — это путь к ее объекту, как показано на графиках выше.
reader = tf.train.load_checkpoint('./tf_ckpts/')
shape_from_key = reader.get_variable_to_shape_map()
dtype_from_key = reader.get_variable_to_dtype_map()
sorted(shape_from_key.keys())
['_CHECKPOINTABLE_OBJECT_GRAPH', 'iterator/.ATTRIBUTES/ITERATOR_STATE', 'net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', 'net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', 'optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', 'save_counter/.ATTRIBUTES/VARIABLE_VALUE', 'step/.ATTRIBUTES/VARIABLE_VALUE']
Итак, если вас интересует значение net.l1.kernel
, вы можете получить его с помощью следующего кода:
key = 'net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE'
print("Shape:", shape_from_key[key])
print("Dtype:", dtype_from_key[key].name)
Shape: [1, 5] Dtype: float32
Он также предоставляет метод get_tensor
, позволяющий вам проверять значение переменной:
reader.get_tensor(key)
array([[4.5854754, 4.607731 , 4.649179 , 4.8474874, 5.121 ]], dtype=float32)
Отслеживание объектов
Контрольные точки сохраняют и восстанавливают значения объектов tf.Variable
, «отслеживая» любую переменную или отслеживаемый объект, установленный в одном из его атрибутов. При выполнении сохранения переменные рекурсивно собираются из всех доступных отслеживаемых объектов.
Как и в случае прямого назначения атрибутов, например self.l1 = tf.keras.layers.Dense(5)
, назначение списков и словарей атрибутам будет отслеживать их содержимое.
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')
restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # Not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()
Вы можете заметить объекты-оболочки для списков и словарей. Эти оболочки являются версиями базовых структур данных с контрольными точками. Как и при загрузке на основе атрибутов, эти оболочки восстанавливают значение переменной, как только она добавляется в контейнер.
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1) # Restores v1, from restore() in the previous cell
assert 1. == v1.numpy()
ListWrapper([])
Отслеживаемые объекты включают tf.train.Checkpoint
, tf.Module
и его подклассы (например keras.layers.Layer
и keras.Model
), а также распознанные контейнеры Python:
-
dict
(иcollections.OrderedDict
) -
list
-
tuple
(иcollections.namedtuple
,typing.NamedTuple
)
Другие типы контейнеров не поддерживаются , в том числе:
-
collections.defaultdict
-
set
Все остальные объекты Python игнорируются , в том числе:
-
int
-
string
-
float
Резюме
Объекты TensorFlow предоставляют простой автоматический механизм для сохранения и восстановления значений переменных, которые они используют.