Salvar e carregar modelos

O TensorFlow.js fornece funcionalidade para salvar e carregar modelos que foram criados com a API de camadas ou convertidos à partir de modelos TensorFlow.js existentes. Estes podem ser modelos que você treinou ou aqueles treinados por outros. Um dos principais benefícios do uso da API de camadas é que os modelos criados com ela são serializáveis e é isso que exploraremos nesse tutorial.

Este tutorial focará em salvar e carregar modelos TensorFlow.js (identificáveis por arquivos JSON). Nós também podemos importar modelos TensorFlow Python. O carregamento desses modelos é abordado nos dois tutoriais a seguir:

Salvar um tf.Model

tf.Model e tf.Sequential fornecem uma função model.save que permite você salvar a topologia e os pesos de um modelo.

  • Topologia: Este é um arquivo descrevendo a arquitetura de um modelo (ou seja, quais operações ele usa). Ele contém referências aos pesos do modelo que são armazenados externamente.

  • Pesos: Estes são arquivos binários que armazenam os pesos de um determinado modelo em um formato eficiente. Eles são geralmente armazenados na mesma pasta que a topologia.

Vamos dar uma olhada no que o código para salvar um modelo se parece:

const saveResult = await model.save('localstorage://my-model-1');

Algumas coisas a serem observadas:

  • O método save recebe um argumento string semelhante à uma URL que começa com um esquema. Isso descreve o tipo de destino no qual estamos tentado salvar um modelo. No exemplo acima, o esquema é localstorage://.
  • O esquema é seguido por um caminho. No exemplo acima, o caminho é my-model-1.
  • O método save é assíncrono.
  • O valor do retorno de model.save é um objeto JSON que carrega informações como o tamanho de bytes da topologia e dos pesos do modelo.
  • O ambiente usado para salvar o modelo não afeta quais ambientes pode carregá-lo. Salvar um modelo em node.js não impede que ele seja carregado no navegador.

A seguir, examinaremos os diferentes esquemas disponíveis:

Local Storage (Somente navegador)

Esquema: localstorage://

await model.save('localstorage://my-model');

Isso salva um modelo com o nome my-model no local storage do navegador. Isso persistirá entre atualizações, embora o local storage possa ser limpo pelos usuários ou pelo navegador se o espaço se tornar um problema. Cada navegador também define seu próprio limite de quantos dados podem ser armazenados no local storage para um determinado domínio.

IndexedDB (Somente navegador)

Esquema: indexeddb://

await model.save('indexeddb://my-model');

Isso salva um modelo no armazenamento do IndexedDB do navegador. Como o local storage, ele persiste entre atualizações, também tende a ter limites maiores no tamanho dos objetos armazenados.

Downloads de Arquivos (Somente navegador)

Esquema: downloads://

await model.save('downloads://my-model');

Isso fará com que o navegador faça o download dos arquivos do modelo na máquina do usuário. Dois arquivos serão produzidos:

  1. Uma arquivo de JSON chamado [my-model].json, que carrega a topologia e a referência ao arquivo de pesos descrito abaixo.
  2. Um arquivo binário com os valores dos pesos nomeado [my-model].weights.bin.

Você pode mudar o nome [my-model] para obter arquivo com um nome diferente.

Como o arquivo .json aponta para o arquivo .bin usando caminho relativo, os dois arquivos devem estar na mesma pasta.

Nota: alguns navegadores exigem que o usuário conceda permissão antes que mais de um arquivo possa ser baixado ao mesmo tempo.

Requisição HTTP(S)

Esquema: http:// ou https://

await model.save('http://model-server.domain/upload')

Isso criará uma requisição web para salvar um modelo em um servidor remoto. Você deve estar no controle desse servidor remoto para garantir que ele possa lidar com a solicitação.

O modelo será enviado para o servidor HTTP especificado através de uma requisição POST.

O corpo do POST está no formato multipart/form-data e consiste de dois arquivos.

  1. Uma arquivo JSON nomeado model.json, que carrega a topologia e a referência ao arquivo de pesos descrito abaixo.
  2. Um arquivo binário com os valores dos pesos nomeado model.weights.bin.

Observe que os nomes dos arquivos serão sempre como especificado anteriormente (o nome está incorporado à função). Esta documentação da api contém um trecho de código Python que demonstra como alguém pode usar framework web flask para lidar com a solicitação originada pelo save.

Geralmente, você precisará passar mais argumentos ou headers na requisição para o seu servidor HTTP (por exemplo, para autenticação ou se desejar especificar uma pasta na qual o modelo deve ser salvo). Você pode obter controle refinado sobre esses aspectos das solicitações de save substituindo o argumento da string de URL em tf.io.browserHTTPRequest. Esta API oferece maior flexibilidade no controle de solicitações HTTP.

Por exemplo:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

Sistema de Arquivo Nativo (Somente Node.js)

Esquema: file://

await model.save('file:///path/to/my-model');

Ao rodar no Node.js, também temos acesso direto ao sistema de arquivos e podemos salvar modelos lá. O comando acima salvará dois arquivos no caminho especificado após o esquema.

  1. Uma arquivo JSON nomeado [model].json, que carrega a topologia e a referência para o arquivo de pesos descrito abaixo.
  2. Um arquivo binário carregando os valores dos pesos nomeado [model].weights.bin.

Observe que o nome dos dois arquivos será sempre exatamente como especificado acima (o nome está incorporado à função).

Carregando um tf.Model

Dado um modelo que foi salvo usando um dos método acima, podemos carregar o modelo usando a API tf.loadLayersModel.

Vamos dar uma olhada em como é um código para carregar um modelo:

const model = await tf.loadLayersModel('localstorage://my-model-1');

Algumas coisas para observar:

  • Assim como model.save(), a função loadLayersModel recebe uma string semelhante à URL que começa com um esquema. Isso descreve o tipo de destino do qual estamos tentando carregar um modelo.
  • O esquema é seguido por um caminho. No exemplo acima, o caminho é my-model-1.
  • A string semelhante à URL pode ser substituída por um objeto que condiz com a interface IOHandler.
  • A função tf.loadLayersModel() é assíncrona.
  • O valor de retorno de tf.loadLayersModel é um tf.Model.

A seguir, examinaremos os diferentes esquemas disponíveis.

Local Storage (Somente no navegador)

Esquema: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

Isso carrega um modelo chamado my-model do local storage do navegador.

IndexedDB (Somente no navegador)

Esquema: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

Isso carrega um modelo do armazenamento do IndexedDB do navegador.

HTTP(S)

Esquema: http:// ou https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

Isso carrega um modelo de um endpoint http. Depois de carregar um arquivo json, a função fará requisições para os arquivo .bin correspondentes que o arquivo json referencia.

NOTA: Essa implementação depende da presença do método fetch. Se você está em um ambiente que não fornece o método fetch nativamente, você pode fornecer um método global nomeado fetch que satisfaça a interface ou use uma biblioteca como node-fetch.

Sistema de Arquivo Nativo (Somente no Node.j)

Esquema: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

Ao executar no Node.js, também temos acesso direto ao sistema de arquivo e podemos carregar modelos de lá. Observe que na chamada da função acima, nós referenciamos o próprio arquivo model.json (enquanto que ao salvar, especificamos a pasta). O(s) arquivo(s) .bin correspondente(s) deve estar na mesma pasta do arquivo .json.

Carregando modelos com IOHandlers

Se os esquemas acima não são suficientes para sua necessidade, você pode implementar um comportamento de carregamento personalizado com um IOHandler. Um IOHandler fornecido pelo TensorFlow.js é tf.io.browserFiles que permite os usuários do navegador carregarem arquivos do modelo no navegador. Veja a documentação para mais informação.

Salvando e Carregando Modelos com IOHandlers personalizados.

Se os esquemas acima não são suficientes para carregar ou salvar seu modelo, você precisa implementar um comportamento de serialização personalizado implementando um IOHandler.

Uma IOHandler é um objeto com um método save e load.

A função save recebe um parâmetro que obedece a interface ModelArtifacts e deve retornar uma promise que resolve com um objeto SaveResult.

A função load não recebe parâmetros e deve retornar uma promise que resolve com um objeto ModelArtifacts. Isso é o mesmo objeto que é passado para save.

Veja BrowserHTTPRequest para um exemplo de como implementar um IOHandler.