Guardar y cargar modelos

TensorFlow.js proporciona funcionalidad para guardar y cargar modelos que se crearon con Layers API o se convirtieron a partir de modelos de TensorFlow existentes. Estos pueden ser modelos que usted mismo haya entrenado o aquellos entrenados por otros. Un beneficio clave de usar la API de Layers es que los modelos creados con ella son serializables y esto es lo que exploraremos en este tutorial.

Este tutorial se centrará en guardar y cargar modelos de TensorFlow.js (identificables por archivos JSON). También podemos importar modelos de TensorFlow Python. La carga de estos modelos se trata en los dos tutoriales siguientes:

Guardar un tf.Model

tf.Model y tf.Sequential proporcionan una función model.save que le permite guardar la topología y los pesos de un modelo.

  • Topología: este es un archivo que describe la arquitectura de un modelo (es decir, qué operaciones utiliza). Contiene referencias a los pesos de los modelos que se almacenan externamente.

  • Pesos: Son archivos binarios que almacenan los pesos de un modelo determinado en un formato eficiente. Generalmente se almacenan en la misma carpeta que la topología.

Echemos un vistazo a cómo se ve el código para guardar un modelo.

const saveResult = await model.save('localstorage://my-model-1');

Algunas cosas a tener en cuenta:

  • El método save toma un argumento de cadena similar a una URL que comienza con un esquema . Esto describe el tipo de destino en el que estamos intentando guardar un modelo. En el ejemplo anterior, el esquema es localstorage://
  • Al esquema le sigue un camino . En el ejemplo anterior, la ruta es my-model-1 .
  • El método save es asincrónico.
  • El valor de retorno de model.save es un objeto JSON que contiene información como los tamaños de bytes de la topología y los pesos del modelo.
  • El entorno utilizado para guardar el modelo no afecta qué entornos pueden cargar el modelo. Guardar un modelo en node.js no impide que se cargue en el navegador.

A continuación examinaremos los diferentes esquemas disponibles.

Almacenamiento local (solo navegador)

Esquema: localstorage://

await model.save('localstorage://my-model');

Esto guarda un modelo con el nombre my-model en el almacenamiento local del navegador. Esto persistirá entre actualizaciones, aunque los usuarios o el propio navegador pueden borrar el almacenamiento local si el espacio se convierte en una preocupación. Cada navegador también establece su propio límite sobre la cantidad de datos que se pueden almacenar en el almacenamiento local para un dominio determinado.

IndexedDB (solo navegador)

Esquema: indexeddb://

await model.save('indexeddb://my-model');

Esto guarda un modelo en el almacenamiento IndexedDB del navegador. Al igual que el almacenamiento local, persiste entre actualizaciones y también tiende a tener límites mayores en el tamaño de los objetos almacenados.

Descargas de archivos (solo navegador)

Esquema: downloads://

await model.save('downloads://my-model');

Esto hará que el navegador descargue los archivos del modelo en la máquina del usuario. Se producirán dos archivos:

  1. Un archivo JSON de texto llamado [my-model].json , que contiene la topología y la referencia al archivo de pesos que se describe a continuación.
  2. Un archivo binario que contiene los valores de peso denominado [my-model].weights.bin .

Puede cambiar el nombre [my-model] para obtener archivos con un nombre diferente.

Debido a que el archivo .json apunta al .bin usando una ruta relativa, los dos archivos deben estar en la misma carpeta.

Solicitud HTTP(S)

Esquema: http:// o https://

await model.save('http://model-server.domain/upload')

Esto creará una solicitud web para guardar un modelo en un servidor remoto. Debe tener el control de ese servidor remoto para asegurarse de que pueda manejar la solicitud.

El modelo se enviará al servidor HTTP especificado mediante una solicitud POST . El cuerpo del POST está en formato multipart/form-data y consta de dos archivos.

  1. Un archivo JSON de texto denominado model.json , que contiene la topología y la referencia al archivo de pesos que se describe a continuación.
  2. Un archivo binario que contiene los valores de peso denominado model.weights.bin .

Tenga en cuenta que el nombre de los dos archivos siempre será exactamente el especificado anteriormente (el nombre está integrado en la función). Este documento API contiene un fragmento de código Python que demuestra cómo se puede usar el marco web flask para manejar la solicitud originada desde save .

A menudo tendrá que pasar más argumentos o solicitar encabezados a su servidor HTTP (por ejemplo, para autenticación o si desea especificar una carpeta en la que se debe guardar el modelo). Puede obtener un control detallado sobre estos aspectos de las solicitudes desde save reemplazando el argumento de la cadena URL en tf.io.browserHTTPRequest . Esta API ofrece una mayor flexibilidad en el control de solicitudes HTTP.

Por ejemplo:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

Sistema de archivos nativo (solo Node.js)

Esquema: file://

await model.save('file:///path/to/my-model');

Cuando ejecutamos Node.js también tenemos acceso directo al sistema de archivos y podemos guardar modelos allí. El comando anterior guardará dos archivos en la path especificada después del scheme .

  1. Un archivo JSON de texto denominado [model].json , que contiene la topología y la referencia al archivo de pesos que se describe a continuación.
  2. Un archivo binario que contiene los valores de peso denominado [model].weights.bin .

Tenga en cuenta que el nombre de los dos archivos siempre será exactamente el especificado anteriormente (el nombre está integrado en la función).

Cargando un tf.Model

Dado un modelo que se guardó usando uno de los métodos anteriores, podemos cargarlo usando la API tf.loadLayersModel .

Echemos un vistazo a cómo se ve el código para cargar un modelo.

const model = await tf.loadLayersModel('localstorage://my-model-1');

Algunas cosas a tener en cuenta:

  • Al igual que model.save() , la función loadLayersModel toma un argumento de cadena similar a una URL que comienza con un esquema . Esto describe el tipo de destino desde el que intentamos cargar un modelo.
  • Al esquema le sigue un camino . En el ejemplo anterior, la ruta es my-model-1 .
  • La cadena similar a una URL se puede reemplazar por un objeto que coincida con la interfaz IOHandler.
  • La función tf.loadLayersModel() es asíncrona.
  • El valor de retorno de tf.loadLayersModel es tf.Model

A continuación examinaremos los diferentes esquemas disponibles.

Almacenamiento local (solo navegador)

Esquema: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

Esto carga un modelo llamado my-model desde el almacenamiento local del navegador.

IndexedDB (solo navegador)

Esquema: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

Esto carga un modelo desde el almacenamiento IndexedDB del navegador.

HTTP(S)

Esquema: http:// o https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

Esto carga un modelo desde un punto final http. Después de cargar el archivo json , la función realizará solicitudes de los archivos .bin correspondientes a los que hace referencia el archivo json .

Sistema de archivos nativo (solo Node.js)

Esquema: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

Cuando ejecutamos Node.js también tenemos acceso directo al sistema de archivos y podemos cargar modelos desde allí. Tenga en cuenta que en la llamada a la función anterior hacemos referencia al archivo model.json en sí (mientras que al guardar especificamos una carpeta). Los archivos .bin correspondientes deben estar en la misma carpeta que el archivo json .

Cargando modelos con IOHandlers

Si los esquemas anteriores no son suficientes para sus necesidades, puede implementar un comportamiento de carga personalizado con un IOHandler . Un IOHandler que proporciona TensorFlow.js es tf.io.browserFiles , que permite a los usuarios del navegador cargar archivos de modelo en el navegador. Consulte la documentación para obtener más información.

Guardar y cargar modelos con IOHandlers personalizados

Si los esquemas anteriores no son suficientes para sus necesidades de carga o guardado, puede implementar un comportamiento de serialización personalizado implementando un IOHandler .

Un IOHandler es un objeto con un método save y load .

La función save toma un parámetro que coincide con la interfaz ModelArtifacts y debe devolver una promesa que se resuelve en un objeto SaveResult .

La función load no toma parámetros y debe devolver una promesa que se resuelve en un objeto ModelArtifacts . Este es el mismo objeto que se pasa para save .

Consulte BrowserHTTPRequest para ver un ejemplo de cómo implementar un IOHandler.