Importando un modelo de Keras a TensorFlow.js

Los modelos de Keras (normalmente creados a través de la API de Python) se pueden guardar en uno de varios formatos . El formato de "modelo completo" se puede convertir al formato de capas de TensorFlow.js, que se puede cargar directamente en TensorFlow.js para inferencia o para capacitación adicional.

El formato de capas TensorFlow.js de destino es un directorio que contiene un archivo model.json y un conjunto de archivos de peso fragmentados en formato binario. El archivo model.json contiene la topología del modelo (también conocido como "arquitectura" o "gráfico": una descripción de las capas y cómo están conectadas) y un manifiesto de los archivos de peso.

Requisitos

El procedimiento de conversión requiere un entorno Python; es posible que desee mantener uno aislado usando pipenv o virtualenv . Para instalar el convertidor, use pip install tensorflowjs .

Importar un modelo de Keras a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo Keras existente al formato TF.js Layers y luego cárguelo en TensorFlow.js.

Paso 1. Convertir un modelo Keras existente al formato TF.js Layers

Los modelos Keras generalmente se guardan a través de model.save(filepath) , que produce un único archivo HDF5 (.h5) que contiene tanto la topología del modelo como los pesos. Para convertir dicho archivo al formato TF.js Layers, ejecute el siguiente comando, donde path/to/my_model.h5 es el archivo Keras .h5 de origen y path/to/tfjs_target_dir es el directorio de salida de destino para los archivos TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: use la API de Python para exportar directamente al formato de capas TF.js

Si tiene un modelo de Keras en Python, puede exportarlo directamente al formato de capas de TensorFlow.js de la siguiente manera:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Paso 2: carga el modelo en TensorFlow.js

Utilice un servidor web para servir los archivos de modelo convertidos que generó en el Paso 1. Tenga en cuenta que es posible que deba configurar su servidor para permitir el uso compartido de recursos entre orígenes (CORS) , a fin de permitir la recuperación de los archivos en JavaScript.

Luego cargue el modelo en TensorFlow.js proporcionando la URL del archivo model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Ahora el modelo está listo para la inferencia, evaluación o reentrenamiento. Por ejemplo, el modelo cargado se puede utilizar inmediatamente para hacer una predicción:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Muchos de los ejemplos de TensorFlow.js adoptan este enfoque y utilizan modelos previamente entrenados que se han convertido y alojado en Google Cloud Storage.

Tenga en cuenta que hace referencia al modelo completo utilizando el nombre de archivo model.json . loadModel(...) recupera model.json y luego realiza solicitudes HTTP(S) adicionales para obtener los archivos de peso fragmentados a los que se hace referencia en el manifiesto de peso model.json . Este enfoque permite que el navegador almacene en caché todos estos archivos (y tal vez mediante servidores de almacenamiento en caché adicionales en Internet), porque el model.json y los fragmentos de peso son cada uno más pequeños que el límite de tamaño de archivo de caché típico. Por tanto, es probable que un modelo se cargue más rápidamente en ocasiones posteriores.

Funciones compatibles

Actualmente, TensorFlow.js Layers solo admite modelos Keras que utilizan construcciones Keras estándar. Los modelos que utilizan operaciones o capas no compatibles (por ejemplo, capas personalizadas, capas Lambda, pérdidas personalizadas o métricas personalizadas) no se pueden importar automáticamente porque dependen de código Python que no se puede traducir de manera confiable a JavaScript.