Контрольные точки модели

Возможность сохранять и восстанавливать состояние модели жизненно важна для ряда приложений, например, для трансферного обучения или для выполнения логических выводов с использованием предварительно обученных моделей. Сохранение параметров модели (веса, смещения и т. д.) в файле или каталоге контрольных точек — один из способов добиться этого.

Этот модуль предоставляет высокоуровневый интерфейс для загрузки и сохранения контрольных точек формата TensorFlow v2 , а также компоненты более низкого уровня, которые записывают и читают из этого формата файла.

Загрузка и сохранение простых моделей

Соответствуя протоколу Checkpointable , многие простые модели можно сериализовать в контрольные точки без какого-либо дополнительного кода:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

и затем эту же контрольную точку можно прочитать, используя:

try model.readCheckpoint(from: directory, name: "LeNet")

Эта реализация по умолчанию для загрузки и сохранения модели будет использовать схему именования на основе пути для каждого тензора в модели, основанную на именах свойств в структурах модели. Например, веса и смещения в первой свертке в модели LeNet-5 будут сохранены под именами conv1/filter и conv1/bias соответственно. При загрузке считыватель контрольных точек будет искать тензоры с этими именами.

Настройка загрузки и сохранения модели

Если вы хотите иметь больший контроль над тем, какие тензоры сохраняются и загружаются, или над именованием этих тензоров, протокол Checkpointable предлагает несколько возможностей настройки.

Чтобы игнорировать свойства определенных типов, вы можете предоставить реализацию ignoredTensorPaths в своей модели, которая возвращает набор строк в форме Type.property . Например, чтобы игнорировать свойство scale на каждом слое внимания, вы можете вернуть ["Attention.scale"] .

По умолчанию косая черта используется для разделения каждого более глубокого уровня модели. Это можно настроить, реализовав checkpointSeparator в вашей модели и предоставив новую строку для этого разделителя.

Наконец, для максимальной настройки именования тензоров вы можете реализовать tensorNameMap и предоставить функцию, которая сопоставляет имя строки по умолчанию, сгенерированное для тензора в модели, с желаемым именем строки в контрольной точке. Чаще всего это будет использоваться для взаимодействия с контрольными точками, созданными с помощью других платформ, каждая из которых имеет свои собственные соглашения об именах и структуры моделей. Пользовательская функция сопоставления обеспечивает максимальную степень настройки имен этих тензоров.

Предоставляются некоторые стандартные вспомогательные функции, такие как функция CheckpointWriter.identityMap по умолчанию (которая просто использует автоматически созданное имя тензорного пути для контрольных точек) или функция CheckpointWriter.lookupMap(table:) , которая может создавать сопоставления из словаря.

Пример того, как можно выполнить пользовательское сопоставление, см. в модели GPT-2 , в которой используется функция сопоставления, соответствующая точной схеме именования, используемой для контрольных точек OpenAI.

Компоненты CheckpointReader и CheckpointWriter.

Для записи контрольных точек расширение, предоставляемое протоколом Checkpointable , использует отражение и ключевые пути для перебора свойств модели и создания словаря, который сопоставляет пути тензора строк со значениями Tensor. Этот словарь предоставляется базовому CheckpointWriter вместе с каталогом, в котором можно записать контрольную точку. Этот CheckpointWriter выполняет задачу создания контрольной точки на диске из этого словаря.

Обратным процессом является чтение, при котором CheckpointReader получает местоположение каталога контрольных точек на диске. Затем он считывает данные из этой контрольной точки и формирует словарь, который сопоставляет имена тензоров внутри контрольной точки с их сохраненными значениями. Этот словарь используется для замены текущих тензоров в модели тензорами из этого словаря.

Как для загрузки, так и для сохранения протокол Checkpointable сопоставляет строковые пути к тензорам соответствующим именам тензоров на диске, используя описанную выше функцию сопоставления.

Если протоколу Checkpointable не хватает необходимой функциональности или требуется больший контроль над процессом загрузки и сохранения контрольной точки, классы CheckpointReader и CheckpointWriter можно использовать отдельно.

Формат контрольной точки TensorFlow v2

Формат контрольных точек TensorFlow v2, кратко описанный в этом заголовке , является форматом второго поколения для контрольных точек модели TensorFlow. Этот формат второго поколения используется с конца 2016 года и имеет ряд улучшений по сравнению с форматом контрольных точек v1. TensorFlow SavedModels использует внутри себя контрольные точки v2 для сохранения параметров модели.

Контрольная точка TensorFlow v2 состоит из каталога со следующей структурой:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

где первый файл хранит метаданные для контрольной точки, а остальные файлы представляют собой двоичные фрагменты, содержащие сериализованные параметры модели.

Файл метаданных индекса содержит типы, размеры, местоположения и имена строк всех сериализованных тензоров, содержащихся в сегментах. Этот индексный файл является наиболее структурно сложной частью контрольной точки и основан на tensorflow::table , который сам по себе основан на SSTable/LevelDB. Этот индексный файл состоит из серии пар ключ-значение, где ключи представляют собой строки, а значения — буферы протокола. Строки сортируются и сжимаются по префиксу. Например: если первая запись — conv1/weight , а следующая conv1/bias , вторая запись использует только часть bias .

Этот общий индексный файл иногда сжимается с помощью сжатия Snappy . Файл SnappyDecompression.swift предоставляет встроенную реализацию Swift распаковки Snappy из сжатого экземпляра данных.

Метаданные заголовка индекса и метаданные тензора кодируются как буферы протокола и кодируются/декодируются напрямую через Swift Protobuf .

Классы CheckpointIndexReader и CheckpointIndexWriter обрабатывают загрузку и сохранение этих индексных файлов как часть общих классов CheckpointReader и CheckpointWriter . Последние используют индексные файлы в качестве основы для определения того, что следует читать и записывать в структурно более простые двоичные фрагменты, содержащие тензорные данные.