Entrenar un modelo usando un trabajador web

En este tutorial, explorará una aplicación web de ejemplo que utiliza un trabajador web para entrenar una red neuronal recurrente (RNN) para realizar sumas de enteros. La aplicación de ejemplo no define explícitamente el operador de suma. En cambio, entrena el RNN utilizando sumas de ejemplo.

¡Por supuesto, esta no es la forma más eficiente de sumar dos números enteros! Pero el tutorial demuestra una técnica importante en ML web: cómo realizar cálculos de larga duración sin bloquear el hilo principal, que maneja la lógica de la interfaz de usuario.

La aplicación de ejemplo para este tutorial está disponible en línea , por lo que no necesita descargar ningún código ni configurar un entorno de desarrollo. Si desea ejecutar el código localmente, complete los pasos opcionales en Ejecutar el ejemplo localmente . Si no desea configurar un entorno de desarrollo, puede pasar a Explorar el ejemplo .

El código de ejemplo está disponible en GitHub .

(Opcional) Ejecute el ejemplo localmente

Requisitos previos

Para ejecutar la aplicación de ejemplo localmente, necesita tener instalado lo siguiente en su entorno de desarrollo:

Instalar y ejecutar la aplicación de ejemplo

  1. Clona o descarga el repositorio tfjs-examples .
  2. Cambie al directorio addition-rnn-webworker :

    cd tfjs-examples/addition-rnn-webworker
    
  3. Instalar dependencias:

    yarn
    
  4. Inicie el servidor de desarrollo:

    yarn run watch
    

Explora el ejemplo

Abra la aplicación de ejemplo . (O, si está ejecutando el ejemplo localmente, vaya a http://localhost:1234 en su navegador).

Debería ver una página titulada TensorFlow.js: Addition RNN . Siga las instrucciones para probar la aplicación.

Mediante el formulario web, puede actualizar algunos de los parámetros utilizados para entrenar el modelo, incluidos los siguientes:

  • Dígitos : El número máximo de dígitos en los términos a agregar.
  • Tamaño de entrenamiento : la cantidad de ejemplos de entrenamiento que se generarán.
  • Tipo de RNN : uno de SimpleRNN , GRU o LSTM .
  • Tamaño de capa oculta RNN : dimensionalidad del espacio de salida (debe ser un número entero positivo).
  • Tamaño de lote : número de muestras por actualización de gradiente.
  • Entrenar iteraciones : número de veces para entrenar el modelo invocando model.fit()
  • # de ejemplos de prueba : número de cadenas de ejemplo (por ejemplo, 27+41 ) para generar.

Intente entrenar el modelo con diferentes parámetros y vea si puede mejorar la precisión de las predicciones para varios conjuntos de dígitos. Observe también cómo el tiempo de ajuste del modelo se ve afectado por diferentes parámetros.

Explora el código

La aplicación de ejemplo muestra algunos de los parámetros que puede configurar para entrenar un RNN. También demuestra el uso de un trabajador web para entrenar un modelo fuera del hilo principal. Los trabajadores web son importantes en el aprendizaje automático web porque le permiten ejecutar tareas de capacitación computacionalmente costosas en un subproceso en segundo plano, evitando así problemas de rendimiento que puedan afectar al usuario en el subproceso principal. Los hilos principal y de trabajo se comunican entre sí a través de eventos de mensajes.

Para obtener más información sobre los trabajadores web, consulte API de trabajadores web y uso de trabajadores web .

El módulo principal de la aplicación de ejemplo es index.js . El script index.js crea un trabajador web que ejecuta el módulo worker.js :

const worker =
    new Worker(new URL('./worker.js', import.meta.url), {type: 'module'});

index.js se compone en gran medida de una única función, runAdditionRNNDemo , que maneja el envío del formulario, procesa los datos del formulario, pasa los datos del formulario al trabajador, espera a que el trabajador entrene el modelo y devuelva los resultados, y luego muestra los resultados en la página. .

Para enviar los datos del formulario al trabajador, el script invoca postMessage en el trabajador:

worker.postMessage({
  digits,
  trainingSize,
  rnnType,
  layers,
  hiddenSize,
  trainIterations,
  batchSize,
  numTestExamples
});

El trabajador escucha este mensaje y pasa los datos del formulario a funciones que preparan los datos e inician el entrenamiento:

self.addEventListener('message', async (e) => {
  const { digits, trainingSize, rnnType, layers, hiddenSize, trainIterations, batchSize, numTestExamples } = e.data;
  const demo = new AdditionRNNDemo(digits, trainingSize, rnnType, layers, hiddenSize);
  await demo.train(trainIterations, batchSize, numTestExamples);
})

Durante la capacitación, el trabajador puede enviar dos tipos de mensajes diferentes, uno con isPredict establecido en true

self.postMessage({
  isPredict: true,
  i, iterations, modelFitTime,
  lossValues, accuracyValues,
});

y el otro con isPredict establecido en false .

self.postMessage({
  isPredict: false,
  isCorrect, examples
});

Cuando el hilo de la interfaz de usuario ( index.js ) maneja eventos de mensajes, verifica el indicador isPredict para determinar la forma de los datos devueltos por el trabajador. Si isPredict es verdadero, los datos deben representar una predicción y el script actualiza la página usando tfjs-vis . Si isPredict es falso, el script ejecuta un bloque de código que asume que los datos representan ejemplos. Envuelve los datos en HTML e inserta el HTML en la página.

Que sigue

Este tutorial proporciona un ejemplo del uso de un trabajador web para evitar bloquear el hilo de la interfaz de usuario con un proceso de capacitación de larga duración. Para obtener más información sobre los beneficios de realizar cálculos costosos en un hilo en segundo plano, consulte Usar trabajadores web para ejecutar JavaScript desde el hilo principal del navegador .

Para obtener más información sobre cómo entrenar un modelo TensorFlow.js, consulte Modelos de entrenamiento .