โมเดลที่ใช้ TensorFlow GraphDef (โดยทั่วไปจะสร้างผ่าน Python API) สามารถบันทึกในรูปแบบใดรูปแบบหนึ่งต่อไปนี้:
- โมเดลที่บันทึกไว้ ของ TensorFlow
- โมเดลโฟรเซ่น
- โมดูลฮับเทนเซอร์โฟลว์
รูปแบบข้างต้นทั้งหมดสามารถแปลงได้โดย ตัวแปลง TensorFlow.js ให้เป็นรูปแบบที่สามารถโหลดลงใน TensorFlow.js ได้โดยตรงเพื่อการอนุมาน
(หมายเหตุ: TensorFlow เลิกใช้รูปแบบบันเดิลเซสชันแล้ว โปรดย้ายโมเดลของคุณเป็นรูปแบบ SavedModel)
ความต้องการ
ขั้นตอนการแปลงต้องใช้สภาพแวดล้อม Python คุณอาจต้องการแยกอันหนึ่งโดยใช้ pipenv หรือ virtualenv
ในการติดตั้งตัวแปลงให้รันคำสั่งต่อไปนี้:
pip install tensorflowjs
การนำเข้าโมเดล TensorFlow ลงใน TensorFlow.js นั้นเป็นกระบวนการที่มีสองขั้นตอน ขั้นแรก แปลงโมเดลที่มีอยู่เป็นรูปแบบเว็บ TensorFlow.js จากนั้นโหลดลงใน TensorFlow.js
ขั้นตอนที่ 1 แปลงโมเดล TensorFlow ที่มีอยู่เป็นรูปแบบเว็บ TensorFlow.js
รันสคริปต์ตัวแปลงที่จัดทำโดยแพ็คเกจ pip:
ตัวอย่างรุ่นที่บันทึกไว้:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
ตัวอย่างโมเดล Frozen:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
ตัวอย่างโมดูล Tensorflow Hub:
tensorflowjs_converter \
--input_format=tf_hub \
'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
อาร์กิวเมนต์ตำแหน่ง | คำอธิบาย |
---|---|
input_path | เส้นทางแบบเต็มของไดเร็กทอรีโมเดลที่บันทึกไว้ ไดเร็กทอรีบันเดิลเซสชัน ไฟล์โมเดลที่แช่แข็ง หรือตัวจัดการหรือเส้นทางของโมดูล TensorFlow Hub |
output_path | เส้นทางสำหรับอาร์ติแฟกต์เอาต์พุตทั้งหมด |
ตัวเลือก | คำอธิบาย |
---|---|
--input_format | รูปแบบของโมเดลอินพุต ใช้ tf_saved_model สำหรับ SavedModel, tf_frozen_model สำหรับโมเดล Frozen, tf_session_bundle สำหรับบันเดิลเซสชัน, tf_hub สำหรับโมดูล TensorFlow Hub และ keras สำหรับ Keras HDF5 |
--output_node_names | ชื่อของโหนดเอาต์พุต คั่นด้วยเครื่องหมายจุลภาค |
--saved_model_tags | ใช้ได้กับการแปลง SavedModel เท่านั้น แท็กของ MetaGraphDef ที่จะโหลด ในรูปแบบที่คั่นด้วยเครื่องหมายจุลภาค ค่าเริ่มต้นที่จะ serve |
--signature_name | ใช้ได้กับการแปลงโมดูล TensorFlow Hub เท่านั้น ลายเซ็นที่จะโหลด ค่าเริ่มต้นเป็น default ดู https://www.tensorflow.org/hub/common_signatures/ |
ใช้คำสั่งต่อไปนี้เพื่อรับข้อความช่วยเหลือโดยละเอียด:
tensorflowjs_converter --help
ไฟล์ที่สร้างจากตัวแปลง
สคริปต์การแปลงด้านบนจะสร้างไฟล์สองประเภท:
-
model.json
: กราฟกระแสข้อมูลและรายการน้ำหนัก -
group1-shard\*of\*
: ชุดของไฟล์น้ำหนักไบนารี
ตัวอย่างเช่น นี่คือผลลัพธ์จากการแปลง MobileNet v2:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
ขั้นตอนที่ 2: กำลังโหลดและทำงานในเบราว์เซอร์
- ติดตั้งแพ็คเกจ tfjs-converter npm:
yarn add @tensorflow/tfjs
หรือ npm install @tensorflow/tfjs
- สร้างอินสแตนซ์ คลาส FrozenModel และเรียกใช้การอนุมาน
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));
ลองชม การสาธิต MobileNet
loadGraphModel
API ยอมรับพารามิเตอร์ LoadOptions
เพิ่มเติม ซึ่งสามารถใช้เพื่อส่งข้อมูลประจำตัวหรือส่วนหัวที่กำหนดเองไปพร้อมกับคำขอ สำหรับรายละเอียด โปรดดู เอกสารประกอบ loadGraphModel()
การดำเนินงานที่รองรับ
ปัจจุบัน TensorFlow.js รองรับชุดการดำเนินการ TensorFlow ที่จำกัด หากโมเดลของคุณใช้ ops ที่ไม่รองรับ สคริปต์ tensorflowjs_converter
จะล้มเหลวและพิมพ์รายการ ops ที่ไม่รองรับในโมเดลของคุณ โปรดแจ้ง ปัญหา สำหรับปฏิบัติการแต่ละรายการเพื่อแจ้งให้เราทราบว่าคุณต้องการการสนับสนุนสำหรับปฏิบัติการใด
โหลดน้ำหนักเท่านั้น
หากคุณต้องการโหลดเฉพาะน้ำหนัก คุณสามารถใช้ข้อมูลโค้ดต่อไปนี้:
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");
// Use `weightMap` ...