Importar un modelo de Keras a TensorFlow.js

Modelos Keras (normalmente creadas a través de la API de Python) se pueden guardar en uno de varios formatos . El formato de "modelo completo" se puede convertir al formato de capas de TensorFlow.js, que se puede cargar directamente en TensorFlow.js para inferencias o para entrenamiento adicional.

El formato Capas TensorFlow.js de destino es un directorio que contiene un model.json archivo y un conjunto de archivos fragmentados peso en formato binario. El model.json archivo contiene tanto el modelo de topología (también conocido como "arquitectura" o "gráfico": una descripción de las capas y cómo están conectados) y un manifiesto de los archivos de peso.

Requisitos

El procedimiento de conversión requiere un entorno Python; es posible que desee mantener un aislado usando uno pipenv o virtualenv . Para instalar el convertidor, el uso pip install tensorflowjs .

Importar un modelo de Keras a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo de Keras existente al formato TF.js Layers y luego cárguelo en TensorFlow.js.

Paso 1. Convierta un modelo de Keras existente al formato TF.js Layers

Modelos Keras se suelen guardar través model.save(filepath) , que produce un único archivo HDF5 (.h5) que contiene tanto la topología de modelo y de los pesos. Para convertir un archivo de este tipo de formato TF.js Capas, ejecute el siguiente comando, donde path/to/my_model.h5 es archivo y la fuente Keras .h5 path/to/tfjs_target_dir es el directorio de salida para los archivos TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: use la API de Python para exportar directamente al formato TF.js Layers

Si tiene un modelo de Keras en Python, puede exportarlo directamente al formato de capas de TensorFlow.js de la siguiente manera:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Paso 2: carga el modelo en TensorFlow.js

Utilice un servidor web para proporcionar los archivos convertidos modelo que ha generado en el paso 1. Tenga en cuenta que puede que tenga que configurar el servidor para permitir el origen cruzado de intercambio de recursos (CORS) , con el fin de permitir que ir a buscar los archivos en JavaScript.

Luego, cargue el modelo en TensorFlow.js proporcionando la URL al archivo model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Ahora el modelo está listo para inferencia, evaluación o reentrenamiento. Por ejemplo, el modelo cargado se puede usar inmediatamente para hacer una predicción:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Muchos de los TensorFlow.js Ejemplos toma este enfoque, el uso de modelos pretrained que han sido convertidos y alojados en Google Cloud Storage.

Tenga en cuenta que usted se refiere a todo el modelo utilizando el model.json nombre de archivo. loadModel(...) va a buscar model.json , y luego realiza peticiones HTTP adicional (S) para obtener los archivos fragmentados de peso que se hace referencia en el model.json manifiesta peso. Este enfoque permite que todos estos archivos para ser almacenado en caché por el navegador (y tal vez por los servidores de almacenamiento en caché adicionales en el Internet), debido a que el model.json y los fragmentos de peso son cada uno menor que el límite típico tamaño del archivo caché. Por lo tanto, es probable que un modelo se cargue más rápidamente en ocasiones posteriores.

Funciones admitidas

Actualmente, TensorFlow.js Layers solo admite modelos de Keras que utilizan construcciones estándar de Keras. Los modelos que utilizan capas o operaciones no compatibles, por ejemplo, capas personalizadas, capas Lambda, pérdidas personalizadas o métricas personalizadas, no se pueden importar automáticamente porque dependen del código Python que no se puede traducir de manera confiable a JavaScript.