Importazione di un modello Keras in TensorFlow.js

Modelli KERAS (tipicamente creati tramite l'API Python) possono essere salvate in uno dei diversi formati . Il formato "intero modello" può essere convertito in formato TensorFlow.js Layers, che può essere caricato direttamente in TensorFlow.js per l'inferenza o per ulteriore formazione.

Il formato Livelli TensorFlow.js di destinazione è una directory che contiene un model.json di file e una serie di file di peso sharded in formato binario. Il model.json file contiene sia la topologia del modello (alias "architettura" o "grafico": la descrizione dei livelli e come sono collegati) e di un manifesto di file di peso.

Requisiti

La procedura di conversione richiede un ambiente Python; si consiglia di mantenere un isolato uno utilizzando pipenv o virtualenv . Per installare il convertitore, l'uso pip install tensorflowjs .

L'importazione di un modello Keras in TensorFlow.js è un processo in due fasi. Innanzitutto, converti un modello Keras esistente in formato TF.js Layers, quindi caricalo in TensorFlow.js.

Passaggio 1. Converti un modello Keras esistente in formato TF.js Layers

Modelli Keras vengono solitamente salvate tramite model.save(filepath) , che produce un singolo file HDF5 (.h5) che contiene sia la topologia del modello e dei pesi. Per convertire un file in formato TF.js Livelli, eseguire il seguente comando, dove path/to/my_model.h5 è la fonte Keras .h5 file e il path/to/tfjs_target_dir è la directory di output di destinazione per i file di TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternativa: usa l'API Python per esportare direttamente nel formato TF.js Layers

Se hai un modello Keras in Python, puoi esportarlo direttamente nel formato TensorFlow.js Layers come segue:

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Passaggio 2: caricare il modello in TensorFlow.js

Utilizza un server web per servire i file del modello convertiti generate nel passaggio 1. Si noti che potrebbe essere necessario configurare il server per consentire la condivisione delle risorse Cross-Origin (CORS) , al fine di consentire il recupero dei file in JavaScript.

Quindi caricare il modello in TensorFlow.js fornendo l'URL al file model.json:

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Ora il modello è pronto per l'inferenza, la valutazione o il riaddestramento. Ad esempio, il modello caricato può essere immediatamente utilizzato per fare una previsione:

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

Molti dei TensorFlow.js esempi adottare questo approccio, utilizzando modelli preaddestrato che sono stati convertiti e ospitati su Google Cloud Storage.

Si noti che si fa riferimento l'intero modello utilizzando il model.json nome del file. loadModel(...) recupera model.json , e poi fa le richieste HTTP aggiuntiva (S) per ottenere i file di peso sharded riferimento nel model.json manifesta peso. Questo approccio permette di tutti questi file da memorizzare nella cache dal browser (e, forse, dai server di caching supplementari su internet), in quanto il model.json e le schegge di peso sono ogni più piccolo del limite tipico dimensioni del file di cache. Pertanto è probabile che un modello si carichi più rapidamente nelle occasioni successive.

Funzionalità supportate

TensorFlow.js Layers attualmente supporta solo i modelli Keras che utilizzano costrutti Keras standard. I modelli che utilizzano operazioni o livelli non supportati, ad esempio livelli personalizzati, livelli Lambda, perdite personalizzate o metriche personalizzate, non possono essere importati automaticamente, perché dipendono da codice Python che non può essere tradotto in modo affidabile in JavaScript.