Pontos de verificação do modelo

A capacidade de salvar e restaurar o estado de um modelo é vital para diversas aplicações, como na aprendizagem por transferência ou para realizar inferências usando modelos pré-treinados. Salvar os parâmetros de um modelo (pesos, desvios, etc.) em um arquivo ou diretório de checkpoint é uma maneira de fazer isso.

Este módulo fornece uma interface de alto nível para carregar e salvar pontos de verificação de formato do TensorFlow v2 , bem como componentes de nível inferior que gravam e leem nesse formato de arquivo.

Carregando e salvando modelos simples

Em conformidade com o protocolo Checkpointable , muitos modelos simples podem ser serializados para pontos de verificação sem qualquer código adicional:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

e então esse mesmo ponto de verificação pode ser lido usando:

try model.readCheckpoint(from: directory, name: "LeNet")

Esta implementação padrão para carregamento e salvamento de modelo usará um esquema de nomenclatura baseado em caminho para cada tensor no modelo que é baseado nos nomes das propriedades dentro das estruturas do modelo. Por exemplo, os pesos e desvios dentro da primeira convolução no modelo LeNet-5 serão salvos com os nomes conv1/filter e conv1/bias , respectivamente. Ao carregar, o leitor de checkpoint irá procurar tensores com esses nomes.

Personalizando o carregamento e salvamento do modelo

Se você deseja ter maior controle sobre quais tensores são salvos e carregados, ou sobre a nomenclatura desses tensores, o protocolo Checkpointable oferece alguns pontos de customização.

Para ignorar propriedades em determinados tipos, você pode fornecer uma implementação de ignoredTensorPaths em seu modelo que retorna um conjunto de strings na forma de Type.property . Por exemplo, para ignorar a propriedade scale em cada camada de Atenção, você poderia retornar ["Attention.scale"] .

Por padrão, uma barra é usada para separar cada nível mais profundo em um modelo. Isso pode ser personalizado implementando checkpointSeparator em seu modelo e fornecendo uma nova string para usar neste separador.

Finalmente, para obter o maior grau de personalização na nomenclatura de tensor, você pode implementar tensorNameMap e fornecer uma função que mapeia do nome de string padrão gerado para um tensor no modelo para um nome de string desejado no ponto de verificação. Mais comumente, isso será usado para interoperar com pontos de verificação gerados com outras estruturas, cada uma com suas próprias convenções de nomenclatura e estruturas de modelo. Uma função de mapeamento customizada oferece o maior grau de customização na forma como esses tensores são nomeados.

Algumas funções auxiliares padrão são fornecidas, como o CheckpointWriter.identityMap padrão (que simplesmente usa o nome do caminho do tensor gerado automaticamente para pontos de verificação) ou a função CheckpointWriter.lookupMap(table:) , que pode construir um mapeamento a partir de um dicionário.

Para obter um exemplo de como o mapeamento personalizado pode ser realizado, consulte o modelo GPT-2 , que usa uma função de mapeamento para corresponder ao esquema de nomenclatura exato usado para os pontos de verificação do OpenAI.

Os componentes CheckpointReader e CheckpointWriter

Para escrita de ponto de verificação, a extensão fornecida pelo protocolo Checkpointable usa reflexão e caminhos-chave para iterar sobre as propriedades de um modelo e gerar um dicionário que mapeia caminhos de tensor de string para valores de tensor. Este dicionário é fornecido a um CheckpointWriter subjacente, juntamente com um diretório no qual gravar o ponto de verificação. Esse CheckpointWriter cuida da tarefa de gerar o ponto de verificação no disco a partir desse dicionário.

O inverso desse processo é a leitura, onde um CheckpointReader recebe a localização de um diretório de ponto de verificação no disco. Em seguida, ele lê esse ponto de verificação e forma um dicionário que mapeia os nomes dos tensores dentro do ponto de verificação com seus valores salvos. Este dicionário é usado para substituir os tensores atuais em um modelo pelos deste dicionário.

Para carregar e salvar, o protocolo Checkpointable mapeia os caminhos da string para tensores para nomes de tensores correspondentes no disco usando a função de mapeamento descrita acima.

Se o protocolo Checkpointable não tiver a funcionalidade necessária ou se desejar mais controle sobre o processo de carregamento e salvamento do ponto de verificação, as classes CheckpointReader e CheckpointWriter poderão ser usadas sozinhas.

O formato do ponto de verificação do TensorFlow v2

O formato de ponto de verificação do TensorFlow v2, conforme brevemente descrito neste cabeçalho , é o formato de segunda geração para pontos de verificação do modelo do TensorFlow. Este formato de segunda geração está em uso desde o final de 2016 e apresenta uma série de melhorias em relação ao formato de ponto de verificação v1. Os SavedModels do TensorFlow usam pontos de verificação v2 dentro deles para salvar os parâmetros do modelo.

Um ponto de verificação do TensorFlow v2 consiste em um diretório com uma estrutura como esta:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

onde o primeiro arquivo armazena os metadados do ponto de verificação e os arquivos restantes são fragmentos binários que contêm os parâmetros serializados do modelo.

O arquivo de metadados de índice contém os tipos, tamanhos, locais e nomes de strings de todos os tensores serializados contidos nos fragmentos. Esse arquivo de índice é a parte estruturalmente mais complexa do ponto de verificação e é baseado em tensorflow::table , que é baseado em SSTable/LevelDB. Este arquivo de índice é composto por uma série de pares chave-valor, onde as chaves são strings e os valores são buffers de protocolo. As strings são classificadas e compactadas por prefixo. Por exemplo: se a primeira entrada for conv1/weight e a próxima conv1/bias , a segunda entrada usará apenas a parte bias .

Às vezes, esse arquivo de índice geral é compactado usando a compactação Snappy . O arquivo SnappyDecompression.swift fornece uma implementação Swift nativa da descompactação Snappy de uma instância de dados compactada.

Os metadados do cabeçalho do índice e os metadados do tensor são codificados como buffers de protocolo e codificados/decodificados diretamente via Swift Protobuf .

As classes CheckpointIndexReader e CheckpointIndexWriter lidam com o carregamento e salvamento desses arquivos de índice como parte das classes abrangentes CheckpointReader e CheckpointWriter . Os últimos usam os arquivos de índice como base para determinar o que ler e gravar nos fragmentos binários estruturalmente mais simples que contêm os dados do tensor.