Importer un modèle Keras dans TensorFlow.js

Modèles KERAS (généralement créés via l'API Python) peuvent être enregistrés dans l' un des formats . Le format "tout le modèle" peut être converti au format TensorFlow.js Layers, qui peut être chargé directement dans TensorFlow.js pour l'inférence ou pour une formation complémentaire.

Le TensorFlow.js cible format couches est un répertoire contenant un model.json fichier et un ensemble de fichiers de poids au format binaire fragmentées. Le model.json fichier contient à la fois la topologie du modèle (alias « architecture » ou « graphique »: une description des couches et la façon dont ils sont connectés) et un manifeste des fichiers de poids.

Exigences

La procédure de conversion nécessite un environnement Python ; vous pouvez garder un cas isolé en utilisant pipenv ou virtualenv . Pour installer le convertisseur, utilisez pip install tensorflowjs .

L'importation d'un modèle Keras dans TensorFlow.js est un processus en deux étapes. Tout d'abord, convertissez un modèle Keras existant au format Layers TF.js, puis chargez-le dans TensorFlow.js.

Étape 1. Convertissez un modèle Keras existant au format TF.js Layers

Modèles KERAS sont généralement enregistrés par model.save(filepath) , qui produit un seul fichier HDF5 (.h5) contenant à la fois la topologie du modèle et les poids. Pour convertir un tel fichier au format TF.js couches, exécutez la commande suivante, où path/to/my_model.h5 est le fichier de la source Keras et path/to/tfjs_target_dir est le répertoire de sortie cible pour les fichiers TF.js:

# bash

tensorflowjs_converter --input_format keras \
                       path/to/my_model.h5 \
                       path/to/tfjs_target_dir

Alternative : utilisez l'API Python pour exporter directement au format TF.js Layers

Si vous disposez d'un modèle Keras en Python, vous pouvez l'exporter directement au format TensorFlow.js Layers comme suit :

# Python

import tensorflowjs as tfjs

def train(...):
    model = keras.models.Sequential()   # for example
    ...
    model.compile(...)
    model.fit(...)
    tfjs.converters.save_keras_model(model, tfjs_target_dir)

Étape 2 : chargez le modèle dans TensorFlow.js

Utiliser un serveur Web pour servir les fichiers convertis modèle que vous avez généré à l' étape 1. Notez que vous devrez peut - être configurer votre serveur pour autoriser le partage des ressources Cross-Origin (CORS) , afin de permettre la récupération des fichiers JavaScript.

Chargez ensuite le modèle dans TensorFlow.js en fournissant l'URL du fichier model.json :

// JavaScript

import * as tf from '@tensorflow/tfjs';

const model = await tf.loadLayersModel('https://foo.bar/tfjs_artifacts/model.json');

Le modèle est maintenant prêt pour l'inférence, l'évaluation ou le recyclage. Par exemple, le modèle chargé peut être immédiatement utilisé pour faire une prédiction :

// JavaScript

const example = tf.fromPixels(webcamElement);  // for example
const prediction = model.predict(example);

La plupart des TensorFlow.js exemples cette approche, en utilisant des modèles pré - entraîné qui ont été convertis et hébergés sur Google Cloud Storage.

Notez que vous faites référence au modèle entier à l' aide du model.json nom de fichier. loadModel(...) va chercher model.json , et fait alors des requêtes HTTP supplémentaires (S) pour obtenir les fichiers de poids référencés dans le fragmentées model.json manifeste de poids. Cette approche permet par le navigateur (et peut - être par les serveurs de mise en cache supplémentaires sur Internet) tous ces fichiers à être mis en cache, parce que le model.json et les tessons de poids sont chacun inférieur à la limite de la taille du fichier cache typique. Ainsi, un modèle est susceptible de se charger plus rapidement lors d'occasions ultérieures.

Fonctionnalités prises en charge

TensorFlow.js Layers ne prend actuellement en charge que les modèles Keras utilisant des constructions Keras standard. Les modèles utilisant des opérations ou des couches non prises en charge, par exemple des couches personnalisées, des couches Lambda, des pertes personnalisées ou des métriques personnalisées, ne peuvent pas être importés automatiquement, car ils dépendent du code Python qui ne peut pas être traduit de manière fiable en JavaScript.