Importowanie modeli opartych na TensorFlow GraphDef do TensorFlow.js

Modele oparte na TensorFlow GraphDef (zwykle tworzone za pomocą API Pythona) można zapisać w jednym z następujących formatów:

  1. TensorFlow SavedModel
  2. Model mrożony
  3. Moduł koncentratora Tensorflow

Wszystkie powyższe formaty mogą zostać przekształcone przez konwerter TensorFlow.js do formatu, który można załadować bezpośrednio do TensorFlow.js do wnioskowania.

(Uwaga: TensorFlow wycofał format pakietu sesji, prosimy o migrację modeli do formatu SavedModel).

Wymagania

Procedura konwersji wymaga środowiska Python; Może chcesz zachować odosobnione użyciu pipenv lub virtualenv . Aby zainstalować konwerter, uruchom następujące polecenie:

 pip install tensorflowjs

Importowanie modelu TensorFlow do TensorFlow.js to proces dwuetapowy. Najpierw przekonwertuj istniejący model do formatu internetowego TensorFlow.js, a następnie załaduj go do TensorFlow.js.

Krok 1. Konwertuj istniejący model TensorFlow do formatu internetowego TensorFlow.js

Uruchom skrypt konwertera dostarczony przez pakiet pip:

Sposób użycia: Przykład zapisanego modelu:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Przykład modelu zamrożonego:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Przykład modułu Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumenty pozycyjne Opis
input_path Pełna ścieżka zapisanego katalogu modelu, katalogu pakietu sesji, zamrożonego pliku modelu lub uchwytu lub ścieżki modułu TensorFlow Hub.
output_path Ścieżka do wszystkich artefaktów wyjściowych.
Opcje Opis
--input_format Format modelu wejściowego, użyj tf_saved_model dla SavedModel, tf_frozen_model dla modelu zamrożonego, tf_session_bundle dla pakietu sesji, tf_hub dla modułu TensorFlow Hub i keras dla Keras HDF5.
--output_node_names Nazwy węzłów wyjściowych oddzielone przecinkami.
--saved_model_tags Dotyczy tylko konwersji SavedModel, tagów MetaGraphDef do załadowania, w formacie rozdzielanym przecinkami. Domyślnie serve .
--signature_name Dotyczy tylko konwersji modułu TensorFlow Hub, podpisu do załadowania. Domyślnie default . Zobacz https://www.tensorflow.org/hub/common_signatures/

Użyj następującego polecenia, aby uzyskać szczegółową wiadomość pomocy:

tensorflowjs_converter --help

Pliki wygenerowane przez konwerter

Powyższy skrypt konwersji tworzy dwa typy plików:

  • model.json (wykres przepływu danych i wagi oczywisty)
  • group1-shard\*of\* (zbiór plików binarnych wagi)

Na przykład, oto wynik konwersji MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Krok 2: Ładowanie i uruchamianie w przeglądarce

  1. Zainstaluj pakiet tfjs-converter npm

yarn add @tensorflow/tfjs lub npm install @tensorflow/tfjs

  1. Instancję klasy FrozenModel i uruchomić wnioskowanie.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Sprawdź nasze demo MobileNet .

loadGraphModel API akceptuje dodatkowe LoadOptions parametr, który może być używany do wysyłania poświadczeń lub niestandardowe nagłówki wraz z wnioskiem. Proszę zobaczyć loadGraphModel () dokumentację po więcej szczegółów.

Obsługiwane operacje

Obecnie TensorFlow.js obsługuje ograniczony zestaw operacji TensorFlow. Jeśli model wykorzystuje nieobsługiwany op The tensorflowjs_converter skrypt nie zadziała i wydrukować listę nieobsługiwanych ops w modelu. Proszę złożyć problem dla każdego op daj nam znać które ops Need You wsparcie.

Ładowanie samych ciężarków

Jeśli wolisz załadować tylko wagi, możesz użyć poniższego fragmentu kodu.

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");