Les modèles Keras (généralement créés via l'API Python) peuvent être enregistrés dans l'un des nombreux formats . Le format "modèle entier" peut être converti au format TensorFlow.js Layers, qui peut être chargé directement dans TensorFlow.js à des fins d'inférence ou de formation complémentaire.
Le format TensorFlow.js Layers cible est un répertoire contenant un fichier model.json
et un ensemble de fichiers de pondération fragmentés au format binaire. Le fichier model.json
contient à la fois la topologie du modèle (alias "architecture" ou "graphe" : une description des couches et la façon dont elles sont connectées) et un manifeste des fichiers de pondération.
Conditions
La procédure de conversion nécessite un environnement Python ; vous voudrez peut-être en garder un isolé en utilisant pipenv ou virtualenv . Pour installer le convertisseur, utilisez pip install tensorflowjs
.
L'importation d'un modèle Keras dans TensorFlow.js est un processus en deux étapes. Tout d'abord, convertissez un modèle Keras existant au format TF.js Layers, puis chargez-le dans TensorFlow.js.
Étape 1. Convertir un modèle Keras existant au format TF.js Layers
Les modèles Keras sont généralement enregistrés via model.save(filepath)
, qui produit un seul fichier HDF5 (.h5) contenant à la fois la topologie du modèle et les poids. Pour convertir un tel fichier au format TF.js Layers, exécutez la commande suivante, où path/to/my_model.h5
est le fichier source Keras .h5 et path/to/tfjs_target_dir
est le répertoire de sortie cible pour les fichiers TF.js :
# bash
tensorflowjs_converter --input_format keras \
path/to/my_model.h5 \
path/to/tfjs_target_dir
Alternative : utilisez l'API Python pour exporter directement au format TF.js Layers
Si vous avez un modèle Keras en Python, vous pouvez l'exporter directement au format TensorFlow.js Layers comme suit :
# Python
import tensorflowjs as tfjs
def train(...):
model = keras.models.Sequential() # for example
...
model.compile(...)
model.fit(...)
tfjs.converters.save_keras_model(model, tfjs_target_dir)
Étape 2 : Charger le modèle dans TensorFlow.js
Utilisez un serveur Web pour servir les fichiers de modèle convertis que vous avez générés à l'étape 1. Notez que vous devrez peut-être configurer votre serveur pour autoriser le partage de ressources cross-origin (CORS) , afin de permettre la récupération des fichiers en JavaScript.
Chargez ensuite le modèle dans TensorFlow.js en fournissant l'URL du fichier model.json :
// JavaScript
import * as tf from '@tensorflow/tfjs';
const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');
Le modèle est maintenant prêt pour l'inférence, l'évaluation ou le réentraînement. Par exemple, le modèle chargé peut être immédiatement utilisé pour faire une prédiction :
// JavaScript
const example = tf.fromPixels(webcamElement); // for example
const prediction = model.predict(example);
De nombreux exemples TensorFlow.js adoptent cette approche, en utilisant des modèles pré-entraînés qui ont été convertis et hébergés sur Google Cloud Storage.
Notez que vous faites référence au modèle entier en utilisant le nom de fichier model.json
. loadModel(...)
récupère model.json
, puis effectue des requêtes HTTP(S) supplémentaires pour obtenir les fichiers de poids partitionnés référencés dans le manifeste de poids model.json
. Cette approche permet à tous ces fichiers d'être mis en cache par le navigateur (et peut-être par des serveurs de mise en cache supplémentaires sur Internet), car le model.json
et les fragments de poids sont chacun plus petits que la limite de taille de fichier de cache typique. Ainsi, un modèle est susceptible de se charger plus rapidement lors d'occasions ultérieures.
Fonctionnalités prises en charge
TensorFlow.js Layers ne prend actuellement en charge que les modèles Keras utilisant des constructions Keras standard. Les modèles utilisant des opérations ou des couches non prises en charge (par exemple, des couches personnalisées, des couches Lambda, des pertes personnalisées ou des métriques personnalisées) ne peuvent pas être importés automatiquement, car ils dépendent d'un code Python qui ne peut pas être traduit de manière fiable en JavaScript.