Importar un modelo de TensorFlow a TensorFlow.js

Los modelos basados ​​en TensorFlow GraphDef (normalmente creados a través de la API de Python) se pueden guardar en uno de los siguientes formatos:

  1. Modelo guardado de TensorFlow
  2. Modelo congelado
  3. Módulo Tensorflow Hub

Todos los formatos anteriores se pueden convertir mediante el convertidor TensorFlow.js a un formato que se puede cargar directamente en TensorFlow.js para inferencia.

(Nota: TensorFlow ha dejado de usar el formato de paquete de sesión. Migre sus modelos al formato SavedModel).

Requisitos

El procedimiento de conversión requiere un entorno Python; es posible que desee mantener uno aislado usando pipenv o virtualenv .

Para instalar el convertidor, ejecute el siguiente comando:

 pip install tensorflowjs

Importar un modelo de TensorFlow a TensorFlow.js es un proceso de dos pasos. Primero, convierta un modelo existente al formato web TensorFlow.js y luego cárguelo en TensorFlow.js.

Paso 1. Convertir un modelo de TensorFlow existente al formato web TensorFlow.js

Ejecute el script de conversión proporcionado por el paquete pip:

Ejemplo de modelo guardado:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Ejemplo de modelo congelado:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Ejemplo de módulo Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumentos posicionales Descripción
input_path Ruta completa del directorio del modelo guardado, directorio del paquete de sesiones, archivo de modelo congelado o identificador o ruta del módulo TensorFlow Hub.
output_path Ruta de acceso para todos los artefactos de salida.
Opciones Descripción
--input_format El formato del modelo de entrada. Utilice tf_saved_model para SavedModel, tf_frozen_model para modelo congelado, tf_session_bundle para paquete de sesión, tf_hub para módulo TensorFlow Hub y keras para Keras HDF5.
--output_node_names Los nombres de los nodos de salida, separados por comas.
--saved_model_tags Solo aplicable a la conversión de SavedModel. Etiquetas del MetaGraphDef a cargar, en formato separado por comas. Valores predeterminados para serve .
--signature_name Solo aplicable a la conversión del módulo TensorFlow Hub, firma para cargar. Valores predeterminados por default . Ver https://www.tensorflow.org/hub/common_signatures/

Utilice el siguiente comando para obtener un mensaje de ayuda detallado:

tensorflowjs_converter --help

Archivos generados por el convertidor

El script de conversión anterior produce dos tipos de archivos:

  • model.json : el gráfico de flujo de datos y el manifiesto de peso.
  • group1-shard\*of\* : una colección de archivos de peso binarios

Por ejemplo, aquí está el resultado de la conversión de MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Paso 2: cargar y ejecutar en el navegador

  1. Instale el paquete npm tfjs-converter:

yarn add @tensorflow/tfjs o npm install @tensorflow/tfjs

  1. Cree una instancia de la clase FrozenModel y ejecute la inferencia.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Consulte la demostración de MobileNet .

La API loadGraphModel acepta un parámetro LoadOptions adicional, que se puede utilizar para enviar credenciales o encabezados personalizados junto con la solicitud. Para obtener más información, consulte la documentación de loadGraphModel() .

Operaciones soportadas

Actualmente, TensorFlow.js admite un conjunto limitado de operaciones de TensorFlow. Si su modelo utiliza una operación no compatible, el script tensorflowjs_converter fallará e imprimirá una lista de las operaciones no compatibles en su modelo. Presente un problema para cada operación para informarnos para qué operaciones necesita soporte.

Cargando solo las pesas

Si prefiere cargar solo los pesos, puede utilizar el siguiente fragmento de código:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...