Impor model TensorFlow ke TensorFlow.js

Model berbasis TensorFlow GraphDef (biasanya dibuat melalui Python API) dapat disimpan dalam salah satu format berikut:

  1. Model Tersimpan TensorFlow
  2. Model Beku
  3. Modul Tensorflow Hub

Semua format di atas dapat dikonversi oleh konverter TensorFlow.js menjadi format yang dapat dimuat langsung ke TensorFlow.js untuk inferensi.

(Catatan: TensorFlow tidak lagi menggunakan format paket sesi. Harap migrasikan model Anda ke format SavedModel.)

Persyaratan

Prosedur konversi memerlukan lingkungan Python; Anda mungkin ingin menyimpan yang terisolasi menggunakan pipenv atau virtualenv .

Untuk menginstal konverter, jalankan perintah berikut:

 pip install tensorflowjs

Mengimpor model TensorFlow ke TensorFlow.js memerlukan proses dua langkah. Pertama, konversikan model yang ada ke format web TensorFlow.js, lalu muat model tersebut ke TensorFlow.js.

Langkah 1. Konversikan model TensorFlow yang ada ke format web TensorFlow.js

Jalankan skrip konverter yang disediakan oleh paket pip:

Contoh Model Tersimpan:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Contoh model beku:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Contoh modul Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Argumen Posisi Keterangan
input_path Jalur lengkap direktori model yang disimpan, direktori bundel sesi, file model yang dibekukan, atau jalur atau pegangan modul TensorFlow Hub.
output_path Jalur untuk semua artefak keluaran.
Pilihan Keterangan
--input_format Format model masukan. Gunakan tf_saved_model untuk SavedModel, tf_frozen_model untuk model beku, tf_session_bundle untuk bundel sesi, tf_hub untuk modul TensorFlow Hub dan keras untuk Keras HDF5.
--output_node_names Nama node keluaran, dipisahkan dengan koma.
--saved_model_tags Hanya berlaku untuk konversi SavedModel. Tag MetaGraphDef yang akan dimuat, dalam format yang dipisahkan koma. Default untuk serve .
--signature_name Hanya berlaku untuk konversi modul TensorFlow Hub, tanda tangan untuk dimuat. Defaultnya ke default . Lihat https://www.tensorflow.org/hub/common_signatures/

Gunakan perintah berikut untuk mendapatkan pesan bantuan terperinci:

tensorflowjs_converter --help

Konverter file yang dihasilkan

Skrip konversi di atas menghasilkan dua jenis file:

  • model.json : Grafik aliran data dan manifes bobot
  • group1-shard\*of\* : Kumpulan file berat biner

Misalnya, berikut adalah keluaran dari konversi MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Langkah 2: Memuat dan menjalankan di browser

  1. Instal paket npm tfjs-converter:

yarn add @tensorflow/tfjs atau npm install @tensorflow/tfjs

  1. Buat instance kelas FrozenModel dan jalankan inferensi.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Lihat demo MobileNet .

API loadGraphModel menerima parameter LoadOptions tambahan, yang dapat digunakan untuk mengirim kredensial atau header khusus bersama dengan permintaan. Untuk detailnya, lihat dokumentasi loadGraphModel() .

Operasi yang didukung

Saat ini TensorFlow.js mendukung serangkaian operasi TensorFlow yang terbatas. Jika model Anda menggunakan operasi yang tidak didukung, skrip tensorflowjs_converter akan gagal dan mencetak daftar operasi yang tidak didukung dalam model Anda. Silakan ajukan masalah untuk setiap operasi agar kami tahu operasi mana yang memerlukan dukungan.

Memuat beban saja

Jika Anda lebih suka memuat bobot saja, Anda dapat menggunakan cuplikan kode berikut:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...