Salva e carica modelli

TensorFlow.js fornisce funzionalità per salvare e caricare modelli che sono stati creati con l'API Layers o convertiti da modelli TensorFlow esistenti. Questi possono essere modelli che hai formato tu stesso o quelli formati da altri. Un vantaggio chiave dell'utilizzo dell'API Layers è che i modelli creati con essa sono serializzabili e questo è ciò che esploreremo in questo tutorial.

Questo tutorial si concentrerà sul salvataggio e sul caricamento dei modelli TensorFlow.js (identificabili dai file JSON). Possiamo anche importare modelli TensorFlow Python. Il caricamento di questi modelli è trattato nei due tutorial seguenti:

Salva un tf.Model

tf.Model e tf.Sequential forniscono entrambi una funzione model.save che consente di salvare la topologia e i pesi di un modello.

  • Topologia: questo è un file che descrive l'architettura di un modello (cioè quali operazioni utilizza). Contiene riferimenti ai pesi dei modelli memorizzati esternamente.

  • Pesi: si tratta di file binari che memorizzano i pesi di un determinato modello in un formato efficiente. Generalmente vengono archiviati nella stessa cartella della topologia.

Diamo un'occhiata a come appare il codice per salvare un modello

const saveResult = await model.save('localstorage://my-model-1');

Alcune cose da notare:

  • Il metodo save accetta un argomento stringa simile a un URL che inizia con uno schema . Questo descrive il tipo di destinazione in cui stiamo cercando di salvare un modello. Nell'esempio sopra lo schema è localstorage://
  • Lo schema è seguito da un percorso . Nell'esempio sopra il percorso è my-model-1 .
  • Il metodo save è asincrono.
  • Il valore restituito di model.save è un oggetto JSON che trasporta informazioni come le dimensioni in byte della topologia e dei pesi del modello.
  • L'ambiente utilizzato per salvare il modello non influisce sugli ambienti che possono caricare il modello. Il salvataggio di un modello in node.js non ne impedisce il caricamento nel browser.

Di seguito esamineremo i diversi schemi disponibili.

Archiviazione locale (solo browser)

Schema: localstorage://

await model.save('localstorage://my-model');

Ciò salva un modello con il nome my-model nella memoria locale del browser. Ciò persisterà tra gli aggiornamenti, sebbene l'archiviazione locale possa essere cancellata dagli utenti o dal browser stesso se lo spazio diventa un problema. Ciascun browser imposta inoltre il proprio limite sulla quantità di dati che possono essere archiviati nella memoria locale per un determinato dominio.

DB indicizzato (solo browser)

Schema: indexeddb://

await model.save('indexeddb://my-model');

Ciò salva un modello nell'archivio IndexedDB del browser. Come l'archiviazione locale, persiste tra un aggiornamento e l'altro, tende anche ad avere limiti maggiori sulla dimensione degli oggetti archiviati.

Download di file (solo browser)

Schema: downloads://

await model.save('downloads://my-model');

Ciò farà sì che il browser scarichi i file del modello sul computer dell'utente. Verranno prodotti due file:

  1. Un file JSON di testo denominato [my-model].json , che contiene la topologia e il riferimento al file dei pesi descritto di seguito.
  2. Un file binario che trasporta i valori del peso denominato [my-model].weights.bin .

Puoi cambiare il nome [my-model] per ottenere file con un nome diverso.

Poiché il file .json punta al .bin utilizzando un percorso relativo, i due file dovrebbero trovarsi nella stessa cartella.

Richiesta HTTP(S).

Schema: http:// o https://

await model.save('http://model-server.domain/upload')

Verrà creata una richiesta Web per salvare un modello su un server remoto. Dovresti avere il controllo di quel server remoto in modo da poter garantire che sia in grado di gestire la richiesta.

Il modello verrà inviato al server HTTP specificato tramite una richiesta POST . Il corpo del POST è nel formato multipart/form-data ed è composto da due file

  1. Un file JSON di testo denominato model.json , che contiene la topologia e il riferimento al file dei pesi descritto di seguito.
  2. Un file binario che trasporta i valori di peso denominato model.weights.bin .

Tieni presente che il nome dei due file sarà sempre esattamente come specificato sopra (il nome è integrato nella funzione). Questo documento API contiene uno snippet di codice Python che dimostra come è possibile utilizzare il framework web flask per gestire la richiesta originata da save .

Spesso dovrai passare più argomenti o richiedere intestazioni al tuo server HTTP (ad esempio per l'autenticazione o se desideri specificare una cartella in cui salvare il modello). Puoi ottenere un controllo capillare su questi aspetti delle richieste dal save sostituendo l'argomento della stringa URL in tf.io.browserHTTPRequest . Questa API offre una maggiore flessibilità nel controllo delle richieste HTTP.

Per esempio:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

File system nativo (solo Node.js)

Schema: file://

await model.save('file:///path/to/my-model');

Quando eseguiamo su Node.js abbiamo anche accesso diretto al filesystem e possiamo salvare lì i modelli. Il comando precedente salverà due file nel path specificato dopo lo scheme .

  1. Un file JSON di testo denominato [model].json , che contiene la topologia e il riferimento al file dei pesi descritto di seguito.
  2. Un file binario che trasporta i valori di peso denominato [model].weights.bin .

Tieni presente che il nome dei due file sarà sempre esattamente come specificato sopra (il nome è integrato nella funzione).

Caricamento di un tf.Model

Dato un modello che è stato salvato utilizzando uno dei metodi sopra indicati, possiamo caricarlo utilizzando l'API tf.loadLayersModel .

Diamo un'occhiata a come appare il codice per caricare un modello

const model = await tf.loadLayersModel('localstorage://my-model-1');

Alcune cose da notare:

  • Come model.save() , la funzione loadLayersModel accetta un argomento stringa simile a un URL che inizia con uno schema . Questo descrive il tipo di destinazione da cui stiamo tentando di caricare un modello.
  • Lo schema è seguito da un percorso . Nell'esempio sopra il percorso è my-model-1 .
  • La stringa simile all'URL può essere sostituita da un oggetto che corrisponde all'interfaccia IOHandler.
  • La funzione tf.loadLayersModel() è asincrona.
  • Il valore restituito di tf.loadLayersModel è tf.Model

Di seguito esamineremo i diversi schemi disponibili.

Archiviazione locale (solo browser)

Schema: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

Questo carica un modello denominato my-model dalla memoria locale del browser.

DB indicizzato (solo browser)

Schema: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

Questo carica un modello dall'archivio IndexedDB del browser.

HTTP(S)

Schema: http:// o https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

Questo carica un modello da un endpoint http. Dopo aver caricato il file json , la funzione effettuerà richieste per i file .bin corrispondenti a cui fa riferimento il file json .

File system nativo (solo Node.js)

Schema: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

Quando eseguiamo su Node.js abbiamo anche accesso diretto al filesystem e possiamo caricare i modelli da lì. Tieni presente che nella chiamata di funzione sopra facciamo riferimento al file model.json stesso (mentre durante il salvataggio specifichiamo una cartella). I file .bin corrispondenti dovrebbero trovarsi nella stessa cartella del file json .

Caricamento di modelli con IOHandlers

Se gli schemi di cui sopra non sono sufficienti per le tue esigenze, puoi implementare un comportamento di caricamento personalizzato con un IOHandler . Un IOHandler fornito da TensorFlow.js è tf.io.browserFiles che consente agli utenti del browser di caricare file di modello nel browser. Consulta la documentazione per ulteriori informazioni.

Salvataggio e caricamento di modelli con IOHandler personalizzati

Se gli schemi sopra riportati non sono sufficienti per le tue esigenze di caricamento o salvataggio, puoi implementare un comportamento di serializzazione personalizzato implementando un IOHandler .

Un IOHandler è un oggetto con un metodo save e load .

La funzione save accetta un parametro che corrisponde all'interfaccia ModelArtifacts e dovrebbe restituire una promessa che si risolve in un oggetto SaveResult .

La funzione load non accetta parametri e dovrebbe restituire una promessa che si risolve in un oggetto ModelArtifacts . Questo è lo stesso oggetto passato a save .

Vedi BrowserHTTPRequest per un esempio di come implementare un IOHandler.