Importando um modelo Keras para TensorFlow.js

Os modelos Keras (normalmente criados por meio da API Python) podem ser salvos em um dos vários formatos . O formato "modelo inteiro" pode ser convertido para o formato TensorFlow.js Layers, que pode ser carregado diretamente no TensorFlow.js para inferência ou treinamento adicional.

O formato de camadas TensorFlow.js de destino é um diretório que contém um arquivo model.json e um conjunto de arquivos de peso fragmentado em formato binário. O arquivo model.json contém a topologia do modelo (também conhecida como "arquitetura" ou "grafo": uma descrição das camadas e como elas estão conectadas) e um manifesto dos arquivos de peso.

Requisitos

O procedimento de conversão requer um ambiente Python; você pode querer manter um isolado usando pipenv ou virtualenv . Para instalar o conversor, use pip install tensorflowjs .

A importação de um modelo Keras para o TensorFlow.js é um processo de duas etapas. Primeiro, converta um modelo Keras existente para o formato TF.js Layers e carregue-o no TensorFlow.js.

Etapa 1. Converter um modelo Keras existente para o formato TF.js Layers

Os modelos Keras geralmente são salvos via model.save(filepath) , que produz um único arquivo HDF5 (.h5) contendo a topologia do modelo e os pesos. Para converter esse arquivo para o formato TF.js Layers, execute o seguinte comando, em que path/to/my_model.h5 é o arquivo Keras .h5 de origem e path/to/tfjs_target_dir é o diretório de saída de destino para os arquivos TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: use a API Python para exportar diretamente para o formato TF.js Layers

Se você tiver um modelo Keras em Python, poderá exportá-lo diretamente para o formato TensorFlow.js Layers da seguinte maneira:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Etapa 2: carregar o modelo no TensorFlow.js

Use um servidor da Web para servir os arquivos de modelo convertidos que você gerou na Etapa 1. Observe que talvez seja necessário configurar seu servidor para permitir o Compartilhamento de Recursos de Origem Cruzada (CORS) , para permitir a busca dos arquivos em JavaScript.

Em seguida, carregue o modelo no TensorFlow.js fornecendo o URL para o arquivo model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Agora o modelo está pronto para inferência, avaliação ou retreinamento. Por exemplo, o modelo carregado pode ser usado imediatamente para fazer uma previsão:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Muitos dos exemplos do TensorFlow.js usam essa abordagem, usando modelos pré-treinados que foram convertidos e hospedados no Google Cloud Storage.

Observe que você se refere ao modelo inteiro usando o nome de arquivo model.json . loadModel(...) busca model.json e, em seguida, faz solicitações HTTP(S) adicionais para obter os arquivos de peso fragmentado referenciados no manifesto de peso model.json . Essa abordagem permite que todos esses arquivos sejam armazenados em cache pelo navegador (e talvez por servidores de cache adicionais na Internet), porque o model.json e os fragmentos de peso são menores do que o limite de tamanho de arquivo de cache típico. Assim, é provável que um modelo seja carregado mais rapidamente em ocasiões subsequentes.

Recursos compatíveis

Atualmente, as camadas do TensorFlow.js são compatíveis apenas com modelos Keras usando construções Keras padrão. Modelos que usam operações ou camadas não suportadas – por exemplo, camadas personalizadas, camadas Lambda, perdas personalizadas ou métricas personalizadas – não podem ser importados automaticamente, porque dependem do código Python que não pode ser traduzido de forma confiável em JavaScript.